From 700ea694b293565ececb7571e4d9613d2c143ca6 Mon Sep 17 00:00:00 2001 From: Niles Salter Date: Tue, 13 Jun 2023 14:55:58 -0600 Subject: [PATCH] Fix pdqSort+heapSort for ranges besides 0..len (#15982) --- lib/std/sort.zig | 76 ++++++++++++++++++++++++++++++++++++-------- lib/std/sort/pdq.zig | 18 +++++------ 2 files changed, 71 insertions(+), 23 deletions(-) diff --git a/lib/std/sort.zig b/lib/std/sort.zig index bf2bf40f89..813bad0741 100644 --- a/lib/std/sort.zig +++ b/lib/std/sort.zig @@ -74,29 +74,29 @@ pub fn heap( /// Sorts in ascending order with respect to the given `lessThan` function. pub fn heapContext(a: usize, b: usize, context: anytype) void { // build the heap in linear time. - var i = b / 2; - while (i > a) : (i -= 1) { - siftDown(i - 1, b, context); + var i = a + (b - a) / 2; + while (i > a) { + i -= 1; + siftDown(a, i, b, context); } // pop maximal elements from the heap. i = b; - while (i > a) : (i -= 1) { - context.swap(a, i - 1); - siftDown(a, i - 1, context); + while (i > a) { + i -= 1; + context.swap(a, i); + siftDown(a, a, i, context); } } -fn siftDown(root: usize, n: usize, context: anytype) void { +fn siftDown(a: usize, root: usize, n: usize, context: anytype) void { var node = root; while (true) { - var child = 2 * node + 1; + var child = a + 2 * (node - a) + 1; if (child >= n) break; // choose the greater child. - if (child + 1 < n and context.lessThan(child, child + 1)) { - child += 1; - } + child += @boolToInt(child + 1 < n and context.lessThan(child, child + 1)); // stop if the invariant holds at `node`. if (!context.lessThan(node, child)) break; @@ -138,6 +138,13 @@ const sort_funcs = &[_]fn (comptime type, anytype, anytype, comptime anytype) vo heap, }; +const context_sort_funcs = &[_]fn (usize, usize, anytype) void{ + // blockContext, + pdqContext, + insertionContext, + heapContext, +}; + const IdAndValue = struct { id: usize, value: i32, @@ -248,11 +255,15 @@ test "sort" { &[_]i32{ 2, 1, 3 }, &[_]i32{ 1, 2, 3 }, }, + &[_][]const i32{ + &[_]i32{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 55, 32, 39, 58, 21, 88, 43, 22, 59 }, + &[_]i32{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 21, 22, 32, 39, 43, 55, 58, 59, 88 }, + }, }; inline for (sort_funcs) |sortFn| { for (u8cases) |case| { - var buf: [8]u8 = undefined; + var buf: [20]u8 = undefined; const slice = buf[0..case[0].len]; @memcpy(slice, case[0]); sortFn(u8, slice, {}, asc_u8); @@ -260,7 +271,7 @@ test "sort" { } for (i32cases) |case| { - var buf: [8]i32 = undefined; + var buf: [20]i32 = undefined; const slice = buf[0..case[0].len]; @memcpy(slice, case[0]); sortFn(i32, slice, {}, asc_i32); @@ -308,6 +319,45 @@ test "sort descending" { } } +test "sort with context in the middle of a slice" { + const Context = struct { + items: []i32, + + pub fn lessThan(ctx: @This(), a: usize, b: usize) bool { + return ctx.items[a] < ctx.items[b]; + } + + pub fn swap(ctx: @This(), a: usize, b: usize) void { + return mem.swap(i32, &ctx.items[a], &ctx.items[b]); + } + }; + + const i32cases = [_][]const []const i32{ + &[_][]const i32{ + &[_]i32{ 0, 1, 8, 3, 6, 5, 4, 2, 9, 7, 10, 55, 32, 39, 58, 21, 88, 43, 22, 59 }, + &[_]i32{ 50, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 21, 22, 32, 39, 43, 55, 58, 59, 88 }, + }, + }; + + const ranges = [_]struct { start: usize, end: usize }{ + .{ .start = 10, .end = 20 }, + .{ .start = 1, .end = 11 }, + .{ .start = 3, .end = 7 }, + }; + + inline for (context_sort_funcs) |sortFn| { + for (i32cases) |case| { + for (ranges) |range| { + var buf: [20]i32 = undefined; + const slice = buf[0..case[0].len]; + @memcpy(slice, case[0]); + sortFn(range.start, range.end, Context{ .items = slice }); + try testing.expectEqualSlices(i32, slice[range.start..range.end], case[1][range.start..range.end]); + } + } + } +} + test "sort fuzz testing" { var prng = std.rand.DefaultPrng.init(0x12345678); const random = prng.random(); diff --git a/lib/std/sort/pdq.zig b/lib/std/sort/pdq.zig index e7042b0c76..23678a79c6 100644 --- a/lib/std/sort/pdq.zig +++ b/lib/std/sort/pdq.zig @@ -43,7 +43,7 @@ pub fn pdqContext(a: usize, b: usize, context: anytype) void { // slices of up to this length get sorted using insertion sort. const max_insertion = 24; // number of allowed imbalanced partitions before switching to heap sort. - const max_limit = std.math.floorPowerOfTwo(usize, b) + 1; + const max_limit = std.math.floorPowerOfTwo(usize, b - a) + 1; // set upper bound on stack memory usage. const Range = struct { a: usize, b: usize, limit: usize }; @@ -100,7 +100,7 @@ pub fn pdqContext(a: usize, b: usize, context: anytype) void { // if the chosen pivot is equal to the predecessor, then it's the smallest element in the // slice. Partition the slice into elements equal to and elements greater than the pivot. // This case is usually hit when the slice contains many duplicate elements. - if (range.a > 0 and !context.lessThan(range.a - 1, pivot)) { + if (range.a > a and !context.lessThan(range.a - 1, pivot)) { range.a = partitionEqual(range.a, range.b, pivot, context); continue; } @@ -284,13 +284,13 @@ fn chosePivot(a: usize, b: usize, pivot: *usize, context: anytype) Hint { if (len >= 8) { if (len >= shortest_ninther) { // find medians in the neighborhoods of `i`, `j` and `k` - i = sort3(i - 1, i, i + 1, &swaps, context); - j = sort3(j - 1, j, j + 1, &swaps, context); - k = sort3(k - 1, k, k + 1, &swaps, context); + sort3(i - 1, i, i + 1, &swaps, context); + sort3(j - 1, j, j + 1, &swaps, context); + sort3(k - 1, k, k + 1, &swaps, context); } - // find the median among `i`, `j` and `k` - j = sort3(i, j, k, &swaps, context); + // find the median among `i`, `j` and `k` and stores it in `j` + sort3(i, j, k, &swaps, context); } pivot.* = j; @@ -301,7 +301,7 @@ fn chosePivot(a: usize, b: usize, pivot: *usize, context: anytype) Hint { }; } -fn sort3(a: usize, b: usize, c: usize, swaps: *usize, context: anytype) usize { +fn sort3(a: usize, b: usize, c: usize, swaps: *usize, context: anytype) void { if (context.lessThan(b, a)) { swaps.* += 1; context.swap(b, a); @@ -316,8 +316,6 @@ fn sort3(a: usize, b: usize, c: usize, swaps: *usize, context: anytype) usize { swaps.* += 1; context.swap(b, a); } - - return b; } fn reverseRange(a: usize, b: usize, context: anytype) void {