std: add expectEqualDeep (#13995)
This commit is contained in:
@@ -670,6 +670,252 @@ pub fn expectStringEndsWith(actual: []const u8, expected_ends_with: []const u8)
|
||||
return error.TestExpectedEndsWith;
|
||||
}
|
||||
|
||||
/// This function is intended to be used only in tests. When the two values are not
|
||||
/// deeply equal, prints diagnostics to stderr to show exactly how they are not equal,
|
||||
/// then returns a test failure error.
|
||||
/// `actual` is casted to the type of `expected`.
|
||||
///
|
||||
/// Deeply equal is defined as follows:
|
||||
/// Primitive types are deeply equal if they are equal using `==` operator.
|
||||
/// Struct values are deeply equal if their corresponding fields are deeply equal.
|
||||
/// Container types(like Array/Slice/Vector) deeply equal when their corresponding elements are deeply equal.
|
||||
/// Pointer values are deeply equal if values they point to are deeply equal.
|
||||
///
|
||||
/// Note: Self-referential structs are not supported (e.g. things like std.SinglyLinkedList)
|
||||
pub fn expectEqualDeep(expected: anytype, actual: @TypeOf(expected)) !void {
|
||||
switch (@typeInfo(@TypeOf(actual))) {
|
||||
.NoReturn,
|
||||
.Opaque,
|
||||
.Frame,
|
||||
.AnyFrame,
|
||||
=> @compileError("value of type " ++ @typeName(@TypeOf(actual)) ++ " encountered"),
|
||||
|
||||
.Undefined,
|
||||
.Null,
|
||||
.Void,
|
||||
=> return,
|
||||
|
||||
.Type => {
|
||||
if (actual != expected) {
|
||||
std.debug.print("expected type {s}, found type {s}\n", .{ @typeName(expected), @typeName(actual) });
|
||||
return error.TestExpectedEqual;
|
||||
}
|
||||
},
|
||||
|
||||
.Bool,
|
||||
.Int,
|
||||
.Float,
|
||||
.ComptimeFloat,
|
||||
.ComptimeInt,
|
||||
.EnumLiteral,
|
||||
.Enum,
|
||||
.Fn,
|
||||
.ErrorSet,
|
||||
=> {
|
||||
if (actual != expected) {
|
||||
std.debug.print("expected {}, found {}\n", .{ expected, actual });
|
||||
return error.TestExpectedEqual;
|
||||
}
|
||||
},
|
||||
|
||||
.Pointer => |pointer| {
|
||||
switch (pointer.size) {
|
||||
// We have no idea what is behind those pointers, so the best we can do is `==` check.
|
||||
.C, .Many => {
|
||||
if (actual != expected) {
|
||||
std.debug.print("expected {*}, found {*}\n", .{ expected, actual });
|
||||
return error.TestExpectedEqual;
|
||||
}
|
||||
},
|
||||
.One => {
|
||||
// Length of those pointers are runtime value, so the best we can do is `==` check.
|
||||
switch (@typeInfo(pointer.child)) {
|
||||
.Fn, .Opaque => {
|
||||
if (actual != expected) {
|
||||
std.debug.print("expected {*}, found {*}\n", .{ expected, actual });
|
||||
return error.TestExpectedEqual;
|
||||
}
|
||||
},
|
||||
else => try expectEqualDeep(expected.*, actual.*),
|
||||
}
|
||||
},
|
||||
.Slice => {
|
||||
if (expected.len != actual.len) {
|
||||
std.debug.print("Slice len not the same, expected {d}, found {d}\n", .{ expected.len, actual.len });
|
||||
return error.TestExpectedEqual;
|
||||
}
|
||||
var i: usize = 0;
|
||||
while (i < expected.len) : (i += 1) {
|
||||
expectEqualDeep(expected[i], actual[i]) catch |e| {
|
||||
std.debug.print("index {d} incorrect. expected {any}, found {any}\n", .{
|
||||
i, expected[i], actual[i],
|
||||
});
|
||||
return e;
|
||||
};
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
|
||||
.Array => |_| {
|
||||
if (expected.len != actual.len) {
|
||||
std.debug.print("Array len not the same, expected {d}, found {d}\n", .{ expected.len, actual.len });
|
||||
return error.TestExpectedEqual;
|
||||
}
|
||||
var i: usize = 0;
|
||||
while (i < expected.len) : (i += 1) {
|
||||
expectEqualDeep(expected[i], actual[i]) catch |e| {
|
||||
std.debug.print("index {d} incorrect. expected {any}, found {any}\n", .{
|
||||
i, expected[i], actual[i],
|
||||
});
|
||||
return e;
|
||||
};
|
||||
}
|
||||
},
|
||||
|
||||
.Vector => |info| {
|
||||
if (info.len != @typeInfo(@TypeOf(actual)).Vector.len) {
|
||||
std.debug.print("Vector len not the same, expected {d}, found {d}\n", .{ info.len, @typeInfo(@TypeOf(actual)).Vector.len });
|
||||
return error.TestExpectedEqual;
|
||||
}
|
||||
var i: usize = 0;
|
||||
while (i < info.len) : (i += 1) {
|
||||
expectEqualDeep(expected[i], actual[i]) catch |e| {
|
||||
std.debug.print("index {d} incorrect. expected {any}, found {any}\n", .{
|
||||
i, expected[i], actual[i],
|
||||
});
|
||||
return e;
|
||||
};
|
||||
}
|
||||
},
|
||||
|
||||
.Struct => |structType| {
|
||||
inline for (structType.fields) |field| {
|
||||
expectEqualDeep(@field(expected, field.name), @field(actual, field.name)) catch |e| {
|
||||
std.debug.print("Field {s} incorrect. expected {any}, found {any}\n", .{ field.name, @field(expected, field.name), @field(actual, field.name) });
|
||||
return e;
|
||||
};
|
||||
}
|
||||
},
|
||||
|
||||
.Union => |union_info| {
|
||||
if (union_info.tag_type == null) {
|
||||
@compileError("Unable to compare untagged union values");
|
||||
}
|
||||
|
||||
const Tag = std.meta.Tag(@TypeOf(expected));
|
||||
|
||||
const expectedTag = @as(Tag, expected);
|
||||
const actualTag = @as(Tag, actual);
|
||||
|
||||
try expectEqual(expectedTag, actualTag);
|
||||
|
||||
// we only reach this loop if the tags are equal
|
||||
switch (expected) {
|
||||
inline else => |val, tag| {
|
||||
try expectEqualDeep(val, @field(actual, @tagName(tag)));
|
||||
},
|
||||
}
|
||||
},
|
||||
|
||||
.Optional => {
|
||||
if (expected) |expected_payload| {
|
||||
if (actual) |actual_payload| {
|
||||
try expectEqualDeep(expected_payload, actual_payload);
|
||||
} else {
|
||||
std.debug.print("expected {any}, found null\n", .{expected_payload});
|
||||
return error.TestExpectedEqual;
|
||||
}
|
||||
} else {
|
||||
if (actual) |actual_payload| {
|
||||
std.debug.print("expected null, found {any}\n", .{actual_payload});
|
||||
return error.TestExpectedEqual;
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
.ErrorUnion => {
|
||||
if (expected) |expected_payload| {
|
||||
if (actual) |actual_payload| {
|
||||
try expectEqualDeep(expected_payload, actual_payload);
|
||||
} else |actual_err| {
|
||||
std.debug.print("expected {any}, found {any}\n", .{ expected_payload, actual_err });
|
||||
return error.TestExpectedEqual;
|
||||
}
|
||||
} else |expected_err| {
|
||||
if (actual) |actual_payload| {
|
||||
std.debug.print("expected {any}, found {any}\n", .{ expected_err, actual_payload });
|
||||
return error.TestExpectedEqual;
|
||||
} else |actual_err| {
|
||||
try expectEqualDeep(expected_err, actual_err);
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
test "expectEqualDeep primitive type" {
|
||||
try expectEqualDeep(1, 1);
|
||||
try expectEqualDeep(true, true);
|
||||
try expectEqualDeep(1.5, 1.5);
|
||||
try expectEqualDeep(u8, u8);
|
||||
try expectEqualDeep(error.Bad, error.Bad);
|
||||
|
||||
// optional
|
||||
{
|
||||
const foo: ?u32 = 1;
|
||||
const bar: ?u32 = 1;
|
||||
try expectEqualDeep(foo, bar);
|
||||
try expectEqualDeep(?u32, ?u32);
|
||||
}
|
||||
// function type
|
||||
{
|
||||
const fnType = struct {
|
||||
fn foo() void {
|
||||
unreachable;
|
||||
}
|
||||
}.foo;
|
||||
try expectEqualDeep(fnType, fnType);
|
||||
}
|
||||
}
|
||||
|
||||
test "expectEqualDeep pointer" {
|
||||
const a = 1;
|
||||
const b = 1;
|
||||
try expectEqualDeep(&a, &b);
|
||||
}
|
||||
|
||||
test "expectEqualDeep composite type" {
|
||||
try expectEqualDeep("abc", "abc");
|
||||
const s1: []const u8 = "abc";
|
||||
const s2 = "abcd";
|
||||
const s3: []const u8 = s2[0..3];
|
||||
try expectEqualDeep(s1, s3);
|
||||
|
||||
const TestStruct = struct { s: []const u8 };
|
||||
try expectEqualDeep(TestStruct{ .s = "abc" }, TestStruct{ .s = "abc" });
|
||||
try expectEqualDeep([_][]const u8{ "a", "b", "c" }, [_][]const u8{ "a", "b", "c" });
|
||||
|
||||
// vector
|
||||
try expectEqualDeep(@splat(4, @as(u32, 4)), @splat(4, @as(u32, 4)));
|
||||
|
||||
// nested array
|
||||
{
|
||||
const a = [2][2]f32{
|
||||
[_]f32{ 1.0, 0.0 },
|
||||
[_]f32{ 0.0, 1.0 },
|
||||
};
|
||||
|
||||
const b = [2][2]f32{
|
||||
[_]f32{ 1.0, 0.0 },
|
||||
[_]f32{ 0.0, 1.0 },
|
||||
};
|
||||
|
||||
try expectEqualDeep(a, b);
|
||||
try expectEqualDeep(&a, &b);
|
||||
}
|
||||
}
|
||||
|
||||
fn printIndicatorLine(source: []const u8, indicator_index: usize) void {
|
||||
const line_begin_index = if (std.mem.lastIndexOfScalar(u8, source[0..indicator_index], '\n')) |line_begin|
|
||||
line_begin + 1
|
||||
|
||||
Reference in New Issue
Block a user