commit 163ebe044b76ada70b2bee2e17b9f3e948d54754 (tree)
parent 9760068826e01e5540da9168d2f02e15957a99cc
Author: Henry John Kupty <hkupty@users.noreply.github.com>
Date: Tue, 7 Oct 2025 18:32:13 +0200
std.mem.countScalar: rework to benefit from simd (#25477)
`findScalarPos` might do repetitive work, even if using simd. For
example, when searching the string `/abcde/fghijk/lm` for the character
`/`, a 16-byte wide search would yield `1000001000000100` but would only
count the first `1` and re-search the remaining of the string.
When testing locally, the difference was quite significative:
```
count scalar
5737 iterations 522.83us per iterations
0 bytes per iteration
worst: 2370us median: 512us stddev: 107.64us
count v2
38333 iterations 78.03us per iterations
0 bytes per iteration
worst: 713us median: 76us stddev: 10.62us
count scalar v2
99565 iterations 29.80us per iterations
0 bytes per iteration
worst: 41us median: 29us stddev: 1.04us
```
Note that `count v2` is a simpler string search, similar to the
remaining version of the simd approach:
```
pub fn countV2(comptime T: type, haystack: []const T, needle: T) usize {
const n = haystack.len;
if (n < 1) return 0;
var count: usize = 0;
for (haystack[0..n]) |item| {
count += @intFromBool(item == needle);
}
return count;
}
```
Which implies the compiler yields some optimized code for a simpler loop
that is more performant than the `findScalarPos`-based approach, hence
the usage of iterative approach for the remaining of the haystack.
Co-authored-by: StAlKeR7779 <stalkek7779@yandex.ru>
Diffstat:
1 file changed, 17 insertions(+), 3 deletions(-)
diff --git a/lib/std/mem.zig b/lib/std/mem.zig
@@ -1706,12 +1706,26 @@ test count {
/// Returns the number of needles inside the haystack
pub fn countScalar(comptime T: type, haystack: []const T, needle: T) usize {
+ const n = haystack.len;
var i: usize = 0;
var found: usize = 0;
- while (findScalarPos(T, haystack, i, needle)) |idx| {
- i = idx + 1;
- found += 1;
+ if (use_vectors_for_comparison and
+ (@typeInfo(T) == .int or @typeInfo(T) == .float) and std.math.isPowerOfTwo(@bitSizeOf(T)))
+ {
+ if (std.simd.suggestVectorLength(T)) |block_size| {
+ const Block = @Vector(block_size, T);
+
+ const letter_mask: Block = @splat(needle);
+ while (n - i >= block_size) : (i += block_size) {
+ const haystack_block: Block = haystack[i..][0..block_size].*;
+ found += std.simd.countTrues(letter_mask == haystack_block);
+ }
+ }
+ }
+
+ for (haystack[i..n]) |item| {
+ found += @intFromBool(item == needle);
}
return found;