extract ThreadPool and WaitGroup from compiler to std lib
This commit is contained in:
152
lib/std/Thread/Pool.zig
Normal file
152
lib/std/Thread/Pool.zig
Normal file
@@ -0,0 +1,152 @@
|
||||
const std = @import("std");
|
||||
const builtin = @import("builtin");
|
||||
const Pool = @This();
|
||||
const WaitGroup = @import("WaitGroup.zig");
|
||||
|
||||
mutex: std.Thread.Mutex = .{},
|
||||
cond: std.Thread.Condition = .{},
|
||||
run_queue: RunQueue = .{},
|
||||
is_running: bool = true,
|
||||
allocator: std.mem.Allocator,
|
||||
threads: []std.Thread,
|
||||
|
||||
const RunQueue = std.SinglyLinkedList(Runnable);
|
||||
const Runnable = struct {
|
||||
runFn: RunProto,
|
||||
};
|
||||
|
||||
const RunProto = *const fn (*Runnable) void;
|
||||
|
||||
pub fn init(pool: *Pool, allocator: std.mem.Allocator) !void {
|
||||
pool.* = .{
|
||||
.allocator = allocator,
|
||||
.threads = &[_]std.Thread{},
|
||||
};
|
||||
|
||||
if (builtin.single_threaded) {
|
||||
return;
|
||||
}
|
||||
|
||||
const thread_count = std.math.max(1, std.Thread.getCpuCount() catch 1);
|
||||
pool.threads = try allocator.alloc(std.Thread, thread_count);
|
||||
errdefer allocator.free(pool.threads);
|
||||
|
||||
// kill and join any threads we spawned previously on error.
|
||||
var spawned: usize = 0;
|
||||
errdefer pool.join(spawned);
|
||||
|
||||
for (pool.threads) |*thread| {
|
||||
thread.* = try std.Thread.spawn(.{}, worker, .{pool});
|
||||
spawned += 1;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn deinit(pool: *Pool) void {
|
||||
pool.join(pool.threads.len); // kill and join all threads.
|
||||
pool.* = undefined;
|
||||
}
|
||||
|
||||
fn join(pool: *Pool, spawned: usize) void {
|
||||
if (builtin.single_threaded) {
|
||||
return;
|
||||
}
|
||||
|
||||
{
|
||||
pool.mutex.lock();
|
||||
defer pool.mutex.unlock();
|
||||
|
||||
// ensure future worker threads exit the dequeue loop
|
||||
pool.is_running = false;
|
||||
}
|
||||
|
||||
// wake up any sleeping threads (this can be done outside the mutex)
|
||||
// then wait for all the threads we know are spawned to complete.
|
||||
pool.cond.broadcast();
|
||||
for (pool.threads[0..spawned]) |thread| {
|
||||
thread.join();
|
||||
}
|
||||
|
||||
pool.allocator.free(pool.threads);
|
||||
}
|
||||
|
||||
pub fn spawn(pool: *Pool, comptime func: anytype, args: anytype) !void {
|
||||
if (builtin.single_threaded) {
|
||||
@call(.auto, func, args);
|
||||
return;
|
||||
}
|
||||
|
||||
const Args = @TypeOf(args);
|
||||
const Closure = struct {
|
||||
arguments: Args,
|
||||
pool: *Pool,
|
||||
run_node: RunQueue.Node = .{ .data = .{ .runFn = runFn } },
|
||||
|
||||
fn runFn(runnable: *Runnable) void {
|
||||
const run_node = @fieldParentPtr(RunQueue.Node, "data", runnable);
|
||||
const closure = @fieldParentPtr(@This(), "run_node", run_node);
|
||||
@call(.auto, func, closure.arguments);
|
||||
|
||||
// The thread pool's allocator is protected by the mutex.
|
||||
const mutex = &closure.pool.mutex;
|
||||
mutex.lock();
|
||||
defer mutex.unlock();
|
||||
|
||||
closure.pool.allocator.destroy(closure);
|
||||
}
|
||||
};
|
||||
|
||||
{
|
||||
pool.mutex.lock();
|
||||
defer pool.mutex.unlock();
|
||||
|
||||
const closure = try pool.allocator.create(Closure);
|
||||
closure.* = .{
|
||||
.arguments = args,
|
||||
.pool = pool,
|
||||
};
|
||||
|
||||
pool.run_queue.prepend(&closure.run_node);
|
||||
}
|
||||
|
||||
// Notify waiting threads outside the lock to try and keep the critical section small.
|
||||
pool.cond.signal();
|
||||
}
|
||||
|
||||
fn worker(pool: *Pool) void {
|
||||
pool.mutex.lock();
|
||||
defer pool.mutex.unlock();
|
||||
|
||||
while (true) {
|
||||
while (pool.run_queue.popFirst()) |run_node| {
|
||||
// Temporarily unlock the mutex in order to execute the run_node
|
||||
pool.mutex.unlock();
|
||||
defer pool.mutex.lock();
|
||||
|
||||
const runFn = run_node.data.runFn;
|
||||
runFn(&run_node.data);
|
||||
}
|
||||
|
||||
// Stop executing instead of waiting if the thread pool is no longer running.
|
||||
if (pool.is_running) {
|
||||
pool.cond.wait(&pool.mutex);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn waitAndWork(pool: *Pool, wait_group: *WaitGroup) void {
|
||||
while (!wait_group.isDone()) {
|
||||
if (blk: {
|
||||
pool.mutex.lock();
|
||||
defer pool.mutex.unlock();
|
||||
break :blk pool.run_queue.popFirst();
|
||||
}) |run_node| {
|
||||
run_node.data.runFn(&run_node.data);
|
||||
continue;
|
||||
}
|
||||
|
||||
wait_group.wait();
|
||||
return;
|
||||
}
|
||||
}
|
||||
46
lib/std/Thread/WaitGroup.zig
Normal file
46
lib/std/Thread/WaitGroup.zig
Normal file
@@ -0,0 +1,46 @@
|
||||
const std = @import("std");
|
||||
const Atomic = std.atomic.Atomic;
|
||||
const assert = std.debug.assert;
|
||||
const WaitGroup = @This();
|
||||
|
||||
const is_waiting: usize = 1 << 0;
|
||||
const one_pending: usize = 1 << 1;
|
||||
|
||||
state: Atomic(usize) = Atomic(usize).init(0),
|
||||
event: std.Thread.ResetEvent = .{},
|
||||
|
||||
pub fn start(self: *WaitGroup) void {
|
||||
const state = self.state.fetchAdd(one_pending, .Monotonic);
|
||||
assert((state / one_pending) < (std.math.maxInt(usize) / one_pending));
|
||||
}
|
||||
|
||||
pub fn finish(self: *WaitGroup) void {
|
||||
const state = self.state.fetchSub(one_pending, .Release);
|
||||
assert((state / one_pending) > 0);
|
||||
|
||||
if (state == (one_pending | is_waiting)) {
|
||||
self.state.fence(.Acquire);
|
||||
self.event.set();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn wait(self: *WaitGroup) void {
|
||||
var state = self.state.fetchAdd(is_waiting, .Acquire);
|
||||
assert(state & is_waiting == 0);
|
||||
|
||||
if ((state / one_pending) > 0) {
|
||||
self.event.wait();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reset(self: *WaitGroup) void {
|
||||
self.state.store(0, .Monotonic);
|
||||
self.event.reset();
|
||||
}
|
||||
|
||||
pub fn isDone(wg: *WaitGroup) bool {
|
||||
const state = wg.state.load(.Acquire);
|
||||
assert(state & is_waiting == 0);
|
||||
|
||||
return (state / one_pending) == 0;
|
||||
}
|
||||
Reference in New Issue
Block a user