commit 2ea55d7153b9832e7f00c8a85ca4941ccde6b0d6 (tree)
parent d828115dabf3b06711788dc2f424a1ef3cedd6a3
Author: Andrew Kelley <andrew@ziglang.org>
Date: Fri, 21 Nov 2025 20:56:29 -0800
Merge pull request #25998 from ziglang/std.Io.Threaded-async-guarantee
std.Io: guarantee when async() returns, task is already completed or has been successfully assigned a unit of concurrency
Diffstat:
5 files changed, 139 insertions(+), 118 deletions(-)
diff --git a/lib/std/Io.zig b/lib/std/Io.zig
@@ -580,6 +580,9 @@ pub const VTable = struct {
/// If it returns `null` it means `result` has been already populated and
/// `await` will be a no-op.
///
+ /// When this function returns non-null, the implementation guarantees that
+ /// a unit of concurrency has been assigned to the returned task.
+ ///
/// Thread-safe.
async: *const fn (
/// Corresponds to `Io.userdata`.
@@ -1024,6 +1027,10 @@ pub const Group = struct {
///
/// `function` *may* be called immediately, before `async` returns.
///
+ /// When this function returns, it is guaranteed that `function` has
+ /// already been called and completed, or it has successfully been assigned
+ /// a unit of concurrency.
+ ///
/// After this is called, `wait` or `cancel` must be called before the
/// group is deinitialized.
///
@@ -1094,6 +1101,10 @@ pub fn Select(comptime U: type) type {
///
/// `function` *may* be called immediately, before `async` returns.
///
+ /// When this function returns, it is guaranteed that `function` has
+ /// already been called and completed, or it has successfully been
+ /// assigned a unit of concurrency.
+ ///
/// After this is called, `wait` or `cancel` must be called before the
/// select is deinitialized.
///
@@ -1524,8 +1535,11 @@ pub fn Queue(Elem: type) type {
/// not guaranteed to be available until `await` is called.
///
/// `function` *may* be called immediately, before `async` returns. This has
-/// weaker guarantees than `concurrent`, making more portable and
-/// reusable.
+/// weaker guarantees than `concurrent`, making more portable and reusable.
+///
+/// When this function returns, it is guaranteed that `function` has already
+/// been called and completed, or it has successfully been assigned a unit of
+/// concurrency.
///
/// See also:
/// * `Group`
diff --git a/lib/std/Io/Threaded.zig b/lib/std/Io/Threaded.zig
@@ -13,6 +13,7 @@ const net = std.Io.net;
const HostName = std.Io.net.HostName;
const IpAddress = std.Io.net.IpAddress;
const Allocator = std.mem.Allocator;
+const Alignment = std.mem.Alignment;
const assert = std.debug.assert;
const posix = std.posix;
@@ -22,10 +23,30 @@ mutex: std.Thread.Mutex = .{},
cond: std.Thread.Condition = .{},
run_queue: std.SinglyLinkedList = .{},
join_requested: bool = false,
-threads: std.ArrayList(std.Thread),
stack_size: usize,
-cpu_count: std.Thread.CpuCountError!usize,
-concurrent_count: usize,
+/// All threads are spawned detached; this is how we wait until they all exit.
+wait_group: std.Thread.WaitGroup = .{},
+/// Maximum thread pool size (excluding main thread) when dispatching async
+/// tasks. Until this limit, calls to `Io.async` when all threads are busy will
+/// cause a new thread to be spawned and permanently added to the pool. After
+/// this limit, calls to `Io.async` when all threads are busy run the task
+/// immediately.
+///
+/// Defaults to a number equal to logical CPU cores.
+async_limit: Io.Limit,
+/// Maximum thread pool size (excluding main thread) for dispatching concurrent
+/// tasks. Until this limit, calls to `Io.concurrent` will increase the thread
+/// pool size.
+///
+/// concurrent tasks. After this number, calls to `Io.concurrent` return
+/// `error.ConcurrencyUnavailable`.
+concurrent_limit: Io.Limit = .unlimited,
+/// Error from calling `std.Thread.getCpuCount` in `init`.
+cpu_count_error: ?std.Thread.CpuCountError,
+/// Number of threads that are unavailable to take tasks. To calculate
+/// available count, subtract this from either `async_limit` or
+/// `concurrent_limit`.
+busy_count: usize = 0,
wsa: if (is_windows) Wsa else struct {} = .{},
@@ -70,8 +91,6 @@ const Closure = struct {
start: Start,
node: std.SinglyLinkedList.Node = .{},
cancel_tid: CancelId,
- /// Whether this task bumps minimum number of threads in the pool.
- is_concurrent: bool,
const Start = *const fn (*Closure) void;
@@ -90,8 +109,6 @@ const Closure = struct {
}
};
-pub const InitError = std.Thread.CpuCountError || Allocator.Error;
-
/// Related:
/// * `init_single_threaded`
pub fn init(
@@ -103,21 +120,20 @@ pub fn init(
/// here.
gpa: Allocator,
) Threaded {
+ if (builtin.single_threaded) return .init_single_threaded;
+
+ const cpu_count = std.Thread.getCpuCount();
+
var t: Threaded = .{
.allocator = gpa,
- .threads = .empty,
.stack_size = std.Thread.SpawnConfig.default_stack_size,
- .cpu_count = std.Thread.getCpuCount(),
- .concurrent_count = 0,
+ .async_limit = if (cpu_count) |n| .limited(n - 1) else |_| .nothing,
+ .cpu_count_error = if (cpu_count) |_| null else |e| e,
.old_sig_io = undefined,
.old_sig_pipe = undefined,
.have_signal_handler = false,
};
- if (t.cpu_count) |n| {
- t.threads.ensureTotalCapacityPrecise(gpa, n - 1) catch {};
- } else |_| {}
-
if (posix.Sigaction != void) {
// This causes sending `posix.SIG.IO` to thread to interrupt blocking
// syscalls, returning `posix.E.INTR`.
@@ -142,19 +158,17 @@ pub fn init(
/// * `deinit` is safe, but unnecessary to call.
pub const init_single_threaded: Threaded = .{
.allocator = .failing,
- .threads = .empty,
.stack_size = std.Thread.SpawnConfig.default_stack_size,
- .cpu_count = 1,
- .concurrent_count = 0,
+ .async_limit = .nothing,
+ .cpu_count_error = null,
+ .concurrent_limit = .nothing,
.old_sig_io = undefined,
.old_sig_pipe = undefined,
.have_signal_handler = false,
};
pub fn deinit(t: *Threaded) void {
- const gpa = t.allocator;
t.join();
- t.threads.deinit(gpa);
if (is_windows and t.wsa.status == .initialized) {
if (ws2_32.WSACleanup() != 0) recoverableOsBugDetected();
}
@@ -173,10 +187,12 @@ fn join(t: *Threaded) void {
t.join_requested = true;
}
t.cond.broadcast();
- for (t.threads.items) |thread| thread.join();
+ t.wait_group.wait();
}
fn worker(t: *Threaded) void {
+ defer t.wait_group.finish();
+
t.mutex.lock();
defer t.mutex.unlock();
@@ -184,12 +200,9 @@ fn worker(t: *Threaded) void {
while (t.run_queue.popFirst()) |closure_node| {
t.mutex.unlock();
const closure: *Closure = @fieldParentPtr("node", closure_node);
- const is_concurrent = closure.is_concurrent;
closure.start(closure);
t.mutex.lock();
- if (is_concurrent) {
- t.concurrent_count -= 1;
- }
+ t.busy_count -= 1;
}
if (t.join_requested) break;
t.cond.wait(&t.mutex);
@@ -387,7 +400,7 @@ const AsyncClosure = struct {
func: *const fn (context: *anyopaque, result: *anyopaque) void,
reset_event: ResetEvent,
select_condition: ?*ResetEvent,
- context_alignment: std.mem.Alignment,
+ context_alignment: Alignment,
result_offset: usize,
alloc_len: usize,
@@ -432,11 +445,10 @@ const AsyncClosure = struct {
fn init(
gpa: Allocator,
- mode: enum { async, concurrent },
result_len: usize,
- result_alignment: std.mem.Alignment,
+ result_alignment: Alignment,
context: []const u8,
- context_alignment: std.mem.Alignment,
+ context_alignment: Alignment,
func: *const fn (context: *const anyopaque, result: *anyopaque) void,
) Allocator.Error!*AsyncClosure {
const max_context_misalignment = context_alignment.toByteUnits() -| @alignOf(AsyncClosure);
@@ -454,10 +466,6 @@ const AsyncClosure = struct {
.closure = .{
.cancel_tid = .none,
.start = start,
- .is_concurrent = switch (mode) {
- .async => false,
- .concurrent => true,
- },
},
.func = func,
.context_alignment = context_alignment,
@@ -470,10 +478,15 @@ const AsyncClosure = struct {
return ac;
}
- fn waitAndDeinit(ac: *AsyncClosure, gpa: Allocator, result: []u8) void {
- ac.reset_event.waitUncancelable();
+ fn waitAndDeinit(ac: *AsyncClosure, t: *Threaded, result: []u8) void {
+ ac.reset_event.wait(t) catch |err| switch (err) {
+ error.Canceled => {
+ ac.closure.requestCancel();
+ ac.reset_event.waitUncancelable();
+ },
+ };
@memcpy(result, ac.resultPointer()[0..result.len]);
- ac.deinit(gpa);
+ ac.deinit(t.allocator);
}
fn deinit(ac: *AsyncClosure, gpa: Allocator) void {
@@ -485,60 +498,50 @@ const AsyncClosure = struct {
fn async(
userdata: ?*anyopaque,
result: []u8,
- result_alignment: std.mem.Alignment,
+ result_alignment: Alignment,
context: []const u8,
- context_alignment: std.mem.Alignment,
+ context_alignment: Alignment,
start: *const fn (context: *const anyopaque, result: *anyopaque) void,
) ?*Io.AnyFuture {
- if (builtin.single_threaded) {
+ const t: *Threaded = @ptrCast(@alignCast(userdata));
+ if (builtin.single_threaded or t.async_limit == .nothing) {
start(context.ptr, result.ptr);
return null;
}
-
- const t: *Threaded = @ptrCast(@alignCast(userdata));
- const cpu_count = t.cpu_count catch {
- return concurrent(userdata, result.len, result_alignment, context, context_alignment, start) catch {
- start(context.ptr, result.ptr);
- return null;
- };
- };
-
const gpa = t.allocator;
- const ac = AsyncClosure.init(gpa, .async, result.len, result_alignment, context, context_alignment, start) catch {
+ const ac = AsyncClosure.init(gpa, result.len, result_alignment, context, context_alignment, start) catch {
start(context.ptr, result.ptr);
return null;
};
t.mutex.lock();
- const thread_capacity = cpu_count - 1 + t.concurrent_count;
+ const busy_count = t.busy_count;
- t.threads.ensureTotalCapacityPrecise(gpa, thread_capacity) catch {
+ if (busy_count >= @intFromEnum(t.async_limit)) {
t.mutex.unlock();
ac.deinit(gpa);
start(context.ptr, result.ptr);
return null;
- };
+ }
- t.run_queue.prepend(&ac.closure.node);
+ t.busy_count = busy_count + 1;
- if (t.threads.items.len < thread_capacity) {
+ const pool_size = t.wait_group.value();
+ if (pool_size - busy_count == 0) {
+ t.wait_group.start();
const thread = std.Thread.spawn(.{ .stack_size = t.stack_size }, worker, .{t}) catch {
- if (t.threads.items.len == 0) {
- assert(t.run_queue.popFirst() == &ac.closure.node);
- t.mutex.unlock();
- ac.deinit(gpa);
- start(context.ptr, result.ptr);
- return null;
- }
- // Rely on other workers to do it.
+ t.wait_group.finish();
+ t.busy_count = busy_count;
t.mutex.unlock();
- t.cond.signal();
- return @ptrCast(ac);
+ ac.deinit(gpa);
+ start(context.ptr, result.ptr);
+ return null;
};
- t.threads.appendAssumeCapacity(thread);
+ thread.detach();
}
+ t.run_queue.prepend(&ac.closure.node);
t.mutex.unlock();
t.cond.signal();
return @ptrCast(ac);
@@ -547,45 +550,42 @@ fn async(
fn concurrent(
userdata: ?*anyopaque,
result_len: usize,
- result_alignment: std.mem.Alignment,
+ result_alignment: Alignment,
context: []const u8,
- context_alignment: std.mem.Alignment,
+ context_alignment: Alignment,
start: *const fn (context: *const anyopaque, result: *anyopaque) void,
) Io.ConcurrentError!*Io.AnyFuture {
if (builtin.single_threaded) return error.ConcurrencyUnavailable;
const t: *Threaded = @ptrCast(@alignCast(userdata));
- const cpu_count = t.cpu_count catch 1;
const gpa = t.allocator;
- const ac = AsyncClosure.init(gpa, .concurrent, result_len, result_alignment, context, context_alignment, start) catch {
+ const ac = AsyncClosure.init(gpa, result_len, result_alignment, context, context_alignment, start) catch
return error.ConcurrencyUnavailable;
- };
+ errdefer ac.deinit(gpa);
t.mutex.lock();
+ defer t.mutex.unlock();
- t.concurrent_count += 1;
- const thread_capacity = cpu_count - 1 + t.concurrent_count;
+ const busy_count = t.busy_count;
- t.threads.ensureTotalCapacity(gpa, thread_capacity) catch {
- t.mutex.unlock();
- ac.deinit(gpa);
+ if (busy_count >= @intFromEnum(t.concurrent_limit))
return error.ConcurrencyUnavailable;
- };
- t.run_queue.prepend(&ac.closure.node);
+ t.busy_count = busy_count + 1;
+ errdefer t.busy_count = busy_count;
- if (t.threads.items.len < thread_capacity) {
- const thread = std.Thread.spawn(.{ .stack_size = t.stack_size }, worker, .{t}) catch {
- assert(t.run_queue.popFirst() == &ac.closure.node);
- t.mutex.unlock();
- ac.deinit(gpa);
+ const pool_size = t.wait_group.value();
+ if (pool_size - busy_count == 0) {
+ t.wait_group.start();
+ errdefer t.wait_group.finish();
+
+ const thread = std.Thread.spawn(.{ .stack_size = t.stack_size }, worker, .{t}) catch
return error.ConcurrencyUnavailable;
- };
- t.threads.appendAssumeCapacity(thread);
+ thread.detach();
}
- t.mutex.unlock();
+ t.run_queue.prepend(&ac.closure.node);
t.cond.signal();
return @ptrCast(ac);
}
@@ -597,7 +597,7 @@ const GroupClosure = struct {
/// Points to sibling `GroupClosure`. Used for walking the group to cancel all.
node: std.SinglyLinkedList.Node,
func: *const fn (*Io.Group, context: *anyopaque) void,
- context_alignment: std.mem.Alignment,
+ context_alignment: Alignment,
alloc_len: usize,
fn start(closure: *Closure) void {
@@ -638,7 +638,7 @@ const GroupClosure = struct {
t: *Threaded,
group: *Io.Group,
context: []const u8,
- context_alignment: std.mem.Alignment,
+ context_alignment: Alignment,
func: *const fn (*Io.Group, context: *const anyopaque) void,
) Allocator.Error!*GroupClosure {
const max_context_misalignment = context_alignment.toByteUnits() -| @alignOf(GroupClosure);
@@ -652,7 +652,6 @@ const GroupClosure = struct {
.closure = .{
.cancel_tid = .none,
.start = start,
- .is_concurrent = false,
},
.t = t,
.group = group,
@@ -678,45 +677,48 @@ fn groupAsync(
userdata: ?*anyopaque,
group: *Io.Group,
context: []const u8,
- context_alignment: std.mem.Alignment,
+ context_alignment: Alignment,
start: *const fn (*Io.Group, context: *const anyopaque) void,
) void {
- if (builtin.single_threaded) return start(group, context.ptr);
-
const t: *Threaded = @ptrCast(@alignCast(userdata));
- const cpu_count = t.cpu_count catch 1;
+ if (builtin.single_threaded or t.async_limit == .nothing)
+ return start(group, context.ptr);
const gpa = t.allocator;
- const gc = GroupClosure.init(gpa, t, group, context, context_alignment, start) catch {
+ const gc = GroupClosure.init(gpa, t, group, context, context_alignment, start) catch
return start(group, context.ptr);
- };
t.mutex.lock();
- // Append to the group linked list inside the mutex to make `Io.Group.async` thread-safe.
- gc.node = .{ .next = @ptrCast(@alignCast(group.token)) };
- group.token = &gc.node;
+ const busy_count = t.busy_count;
- const thread_capacity = cpu_count - 1 + t.concurrent_count;
-
- t.threads.ensureTotalCapacityPrecise(gpa, thread_capacity) catch {
+ if (busy_count >= @intFromEnum(t.async_limit)) {
t.mutex.unlock();
gc.deinit(gpa);
return start(group, context.ptr);
- };
+ }
- t.run_queue.prepend(&gc.closure.node);
+ t.busy_count = busy_count + 1;
- if (t.threads.items.len < thread_capacity) {
+ const pool_size = t.wait_group.value();
+ if (pool_size - busy_count == 0) {
+ t.wait_group.start();
const thread = std.Thread.spawn(.{ .stack_size = t.stack_size }, worker, .{t}) catch {
- assert(t.run_queue.popFirst() == &gc.closure.node);
+ t.wait_group.finish();
+ t.busy_count = busy_count;
t.mutex.unlock();
gc.deinit(gpa);
return start(group, context.ptr);
};
- t.threads.appendAssumeCapacity(thread);
+ thread.detach();
}
+ // Append to the group linked list inside the mutex to make `Io.Group.async` thread-safe.
+ gc.node = .{ .next = @ptrCast(@alignCast(group.token)) };
+ group.token = &gc.node;
+
+ t.run_queue.prepend(&gc.closure.node);
+
// This needs to be done before unlocking the mutex to avoid a race with
// the associated task finishing.
const group_state: *std.atomic.Value(usize) = @ptrCast(&group.state);
@@ -794,25 +796,25 @@ fn await(
userdata: ?*anyopaque,
any_future: *Io.AnyFuture,
result: []u8,
- result_alignment: std.mem.Alignment,
+ result_alignment: Alignment,
) void {
_ = result_alignment;
const t: *Threaded = @ptrCast(@alignCast(userdata));
const closure: *AsyncClosure = @ptrCast(@alignCast(any_future));
- closure.waitAndDeinit(t.allocator, result);
+ closure.waitAndDeinit(t, result);
}
fn cancel(
userdata: ?*anyopaque,
any_future: *Io.AnyFuture,
result: []u8,
- result_alignment: std.mem.Alignment,
+ result_alignment: Alignment,
) void {
_ = result_alignment;
const t: *Threaded = @ptrCast(@alignCast(userdata));
const ac: *AsyncClosure = @ptrCast(@alignCast(any_future));
ac.closure.requestCancel();
- ac.waitAndDeinit(t.allocator, result);
+ ac.waitAndDeinit(t, result);
}
fn cancelRequested(userdata: ?*anyopaque) bool {
diff --git a/lib/std/Io/Threaded/test.zig b/lib/std/Io/Threaded/test.zig
@@ -10,7 +10,7 @@ test "concurrent vs main prevents deadlock via oversubscription" {
defer threaded.deinit();
const io = threaded.io();
- threaded.cpu_count = 1;
+ threaded.async_limit = .nothing;
var queue: Io.Queue(u8) = .init(&.{});
@@ -38,7 +38,7 @@ test "concurrent vs concurrent prevents deadlock via oversubscription" {
defer threaded.deinit();
const io = threaded.io();
- threaded.cpu_count = 1;
+ threaded.async_limit = .nothing;
var queue: Io.Queue(u8) = .init(&.{});
diff --git a/lib/std/Thread.zig b/lib/std/Thread.zig
@@ -1,13 +1,14 @@
-//! This struct represents a kernel thread, and acts as a namespace for concurrency
-//! primitives that operate on kernel threads. For concurrency primitives that support
-//! both evented I/O and async I/O, see the respective names in the top level std namespace.
+//! This struct represents a kernel thread, and acts as a namespace for
+//! concurrency primitives that operate on kernel threads. For concurrency
+//! primitives that interact with the I/O interface, see `std.Io`.
-const std = @import("std.zig");
const builtin = @import("builtin");
-const math = std.math;
-const assert = std.debug.assert;
const target = builtin.target;
const native_os = builtin.os.tag;
+
+const std = @import("std.zig");
+const math = std.math;
+const assert = std.debug.assert;
const posix = std.posix;
const windows = std.os.windows;
const testing = std.testing;
diff --git a/lib/std/Thread/WaitGroup.zig b/lib/std/Thread/WaitGroup.zig
@@ -60,6 +60,10 @@ pub fn isDone(wg: *WaitGroup) bool {
return (state / one_pending) == 0;
}
+pub fn value(wg: *WaitGroup) usize {
+ return wg.state.load(.monotonic) / one_pending;
+}
+
// Spawns a new thread for the task. This is appropriate when the callee
// delegates all work.
pub fn spawnManager(