std: implement detach for WASI-threads

When a thread is detached from the main thread, we automatically
cleanup any allocated memory. For this we first reset the stack-pointer
to the original stack-pointer of the main-thread so we can safely clear
the memory which also contains the thread's stack.
This commit is contained in:
Luuk de Gram
2023-06-23 19:14:55 +02:00
parent 622b7c4746
commit e06ab1b010
2 changed files with 64 additions and 27 deletions

View File

@@ -757,6 +757,8 @@ const WasiThreadImpl = struct {
/// The allocator used to allocate the thread's memory,
/// which is also used during `join` to ensure clean-up.
allocator: std.mem.Allocator,
/// The current state of the thread.
state: State = State.init(.running),
};
/// A meta-data structure used to bootstrap a thread
@@ -775,8 +777,15 @@ const WasiThreadImpl = struct {
/// function upon thread spawn. The above mentioned pointer will be passed
/// to this function pointer as its argument.
call_back: *const fn (usize) void,
/// When a thread is in `detached` state, we must free all of its memory
/// upon thread completion. However, as this is done while still within
/// the thread, we must first jump back to the main thread's stack or else
/// we end up freeing the stack that we're currently using.
original_stack_pointer: [*]u8,
};
const State = Atomic(enum(u8) { running, completed, detached });
fn getCurrentId() Id {
return tls_thread_id;
}
@@ -786,7 +795,11 @@ const WasiThreadImpl = struct {
}
fn detach(self: Impl) void {
_ = self;
switch (self.thread.state.swap(.detached, .SeqCst)) {
.running => {},
.completed => self.join(),
.detached => unreachable,
}
}
fn join(self: Impl) void {
@@ -836,7 +849,7 @@ const WasiThreadImpl = struct {
const Wrapper = struct {
args: @TypeOf(args),
fn entry(ptr: usize) void {
const w = @intToPtr(*@This(), ptr);
const w: *@This() = @ptrFromInt(ptr);
@call(.auto, f, w.args);
}
};
@@ -854,7 +867,7 @@ const WasiThreadImpl = struct {
// start with atleast a single page, which is used as a guard to prevent
// other threads clobbering our new thread.
// Unfortunately, WebAssembly has no notion of read-only segments, so this
// is only a temporary measure until the entire page is "run over".
// is only a best effort.
var bytes: usize = std.wasm.page_size;
bytes = std.mem.alignForward(usize, bytes, 16); // align stack to 16 bytes
@@ -880,16 +893,17 @@ const WasiThreadImpl = struct {
// Allocate the amount of memory required for all meta data.
const allocated_memory = try config.allocator.?.alloc(u8, map_bytes);
const wrapper = @ptrCast(*Wrapper, @alignCast(@alignOf(Wrapper), &allocated_memory[wrapper_offset]));
const wrapper: *Wrapper = @ptrCast(@alignCast(&allocated_memory[wrapper_offset]));
wrapper.* = .{ .args = args };
const instance = @ptrCast(*Instance, @alignCast(@alignOf(Instance), &allocated_memory[instance_offset]));
const instance: *Instance = @ptrCast(@alignCast(&allocated_memory[instance_offset]));
instance.* = .{
.thread = .{ .memory = allocated_memory, .allocator = config.allocator.? },
.tls_offset = tls_offset,
.stack_offset = stack_offset,
.raw_ptr = @ptrToInt(wrapper),
.raw_ptr = @intFromPtr(wrapper),
.call_back = &Wrapper.entry,
.original_stack_pointer = __get_stack_pointer(),
};
const tid = spawnWasiThread(instance);
@@ -903,32 +917,46 @@ const WasiThreadImpl = struct {
return .{ .thread = &instance.thread };
}
/// Bootstrap procedure, called by the HOST environment after thread creation.
/// Bootstrap procedure, called by the host environment after thread creation.
export fn wasi_thread_start(tid: i32, arg: *Instance) void {
__set_stack_pointer(arg.thread.memory.ptr + arg.stack_offset);
__wasm_init_tls(arg.thread.memory.ptr + arg.tls_offset);
WasiThreadImpl.tls_thread_id = @intCast(u32, tid);
@atomicStore(u32, &WasiThreadImpl.tls_thread_id, @intCast(tid), .SeqCst);
// Finished bootstrapping, call user's procedure.
arg.call_back(arg.raw_ptr);
// Thread finished. Reset Thread ID and wake up the main thread if needed.
// We use inline assembly here as we must ensure not to use the stack.
asm volatile (
\\ local.get %[ptr]
\\ i32.const 0
\\ i32.atomic.store 0
:
: [ptr] "r" (&arg.thread.tid.value),
);
asm volatile (
\\ local.get %[ptr]
\\ i32.const 1 # waiters
\\ memory.atomic.notify 0
\\ drop # no need to know the waiters
:
: [ptr] "r" (&arg.thread.tid.value),
);
switch (arg.thread.state.swap(.completed, .SeqCst)) {
.running => {
// reset the Thread ID
asm volatile (
\\ local.get %[ptr]
\\ i32.const 0
\\ i32.atomic.store 0
:
: [ptr] "r" (&arg.thread.tid.value),
);
// Wake the main thread listening to this thread
asm volatile (
\\ local.get %[ptr]
\\ i32.const 1 # waiters
\\ memory.atomic.notify 0
\\ drop # no need to know the waiters
:
: [ptr] "r" (&arg.thread.tid.value),
);
},
.completed => unreachable,
.detached => {
// restore the original stack pointer so we can free the memory
// without having to worry about freeing the stack
__set_stack_pointer(arg.original_stack_pointer);
// Ensure a copy so we don't free the allocator reference itself
var allocator = arg.thread.allocator;
allocator.free(arg.thread.memory);
},
}
}
/// Asks the host to create a new thread for us.
@@ -980,6 +1008,15 @@ const WasiThreadImpl = struct {
: [ptr] "r" (addr),
);
}
/// Returns the current value of the stack pointer
inline fn __get_stack_pointer() [*]u8 {
return asm (
\\ global.get __stack_pointer
\\ local.set %[stack_ptr]
: [stack_ptr] "=r" (-> [*]u8),
);
}
};
const LinuxThreadImpl = struct {

View File

@@ -453,7 +453,7 @@ const WasmImpl = struct {
if (!comptime std.Target.wasm.featureSetHas(builtin.target.cpu.features, .atomics)) {
@compileError("WASI target missing cpu feature 'atomics'");
}
const to: i64 = if (timeout) |to| @intCast(i64, to) else -1;
const to: i64 = if (timeout) |to| @intCast(to) else -1;
const result = asm (
\\local.get %[ptr]
\\local.get %[expected]
@@ -462,7 +462,7 @@ const WasmImpl = struct {
\\local.set %[ret]
: [ret] "=r" (-> u32),
: [ptr] "r" (&ptr.value),
[expected] "r" (@bitCast(i32, expect)),
[expected] "r" (@as(i32, @bitCast(expect))),
[timeout] "r" (to),
);
switch (result) {