Update the host ABI of Wasmtime to return failure (#9675)
* Update the host ABI of Wasmtime to return failure

This commit updates the "array call" ABI of Wasmtime used to transition
from wasm to the host to explicitly return a `bool` indicating whether
the call succeeded or not. Previously functions would implicitly unwind
via `longjmp` and thus no explicit checks were necessary. The burden of
unwinding is now placed on Cranelift-compiled code itself instead of the
caller.
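
To make the new contract concrete, here is a minimal plain-Rust sketch of the idea; it is not Wasmtime's actual code. The names `PENDING_TRAP`, `catch_and_record`, `host_div`, and `call_host` are hypothetical, and a thread-local `String` stands in for the real trap state. The host entry point returns `true` on success. On failure it records what went wrong and returns `false`, and the caller checks that flag before reading any results.

```rust
use std::cell::RefCell;
use std::panic::{self, AssertUnwindSafe};

thread_local! {
    // Hypothetical stand-in for wherever the runtime stores a pending trap.
    static PENDING_TRAP: RefCell<Option<String>> = RefCell::new(None);
}

/// Runs `f`, converting an `Err` or a panic into a recorded trap.
/// Returns `true` if `f` succeeded and `false` if a trap was recorded.
fn catch_and_record(f: impl FnOnce() -> Result<(), String>) -> bool {
    let failure = match panic::catch_unwind(AssertUnwindSafe(f)) {
        Ok(Ok(())) => return true,
        Ok(Err(message)) => message,
        Err(_) => String::from("host function panicked"),
    };
    PENDING_TRAP.with(|slot| *slot.borrow_mut() = Some(failure));
    false
}

/// Hypothetical host function: arguments and results share one buffer,
/// and the return value only says whether the call succeeded.
fn host_div(values: &mut [i64]) -> bool {
    catch_and_record(|| {
        if values[1] == 0 {
            return Err(String::from("division by zero"));
        }
        values[0] /= values[1];
        Ok(())
    })
}

/// Caller side: check the flag first, and only then read the results.
fn call_host(a: i64, b: i64) -> Result<i64, String> {
    let mut values = [a, b];
    if host_div(&mut values) {
        Ok(values[0])
    } else {
        // Take the recorded trap so it cannot leak into a later call.
        Err(PENDING_TRAP
            .with(|slot| slot.borrow_mut().take())
            .unwrap_or_else(|| String::from("trap state missing")))
    }
}
```

In this sketch `call_host(6, 3)` yields `Ok(2)` and `call_host(1, 0)` yields an error carrying the recorded message. In the commit itself the check is emitted by Cranelift as a branch on the returned value, with the failure path calling the `raise` builtin.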

There are a few pieces of rationale for this change:

* Primarily I was implementing initial integration of Pulley where the
  host `setjmp` and `longjmp` cannot be used to maintain the Pulley
  interpreter state. My initial plan for handling this was to handle
  traps a bit differently in Pulley where having direct access to
  whether a function trapped or not in the interpreter bytecode is
  easier to deal with.

* Additionally it's generally not safe to call `longjmp` from Rust as it
  will not run on-stack destructors. This is ok today in the sense that
  we shouldn't have these in Wasmtime, but directly returning to
  compiled code improves our safety story here a bit where we just won't
  have to deal with the problem of skipping destructors.

* In the future I'd like to move away from `setjmp` and `longjmp`
  entirely in the host to a Cranelift-native solution. This change would
  be required for such a migration anyway so it's my hope to make such a
  Cranelift-based implementation easier in the future. This might be
  necessary, for example, when implementing the `exception-handling`
  proposal for Wasmtime.
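
The destructor point in the second bullet can be illustrated with a small sketch. It is an assumption-laden toy rather than Wasmtime code: `Guard` and `host_call` are hypothetical names, and `longjmp` itself is not demonstrated because calling it from Rust is exactly the hazard described. What the sketch shows is the other half of the argument: when failure is reported by returning `false`, the frame unwinds normally and on-stack destructors still run.

```rust
use std::cell::Cell;

/// Hypothetical on-stack resource whose destructor bumps a counter.
struct Guard<'a>(&'a Cell<u32>);

impl Drop for Guard<'_> {
    fn drop(&mut self) {
        self.0.set(self.0.get() + 1);
    }
}

/// Hypothetical host call that reports failure through its return value.
/// Both the failure path and the success path are ordinary returns, so
/// `_guard` is dropped either way.
fn host_call(drops: &Cell<u32>, fail: bool) -> bool {
    let _guard = Guard(drops);
    if fail {
        return false;
    }
    true
}
```

Calling `host_call` once with `fail` set and once without increments the counter both times, which is the property a `longjmp` out of the same frame would not provide.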

Functionally this commit does not remove all usages of
call-`longjmp`-from-Rust. Notably all libcalls and builtins still use
this helper in the trampolines generated in Rust. I plan on going
through the libcalls and updating their ABIs and signatures to reflect
this in follow-up commits. As a result a few preexisting functions that
should go away are marked `#[deprecated]` for internal use in this
commit. I'll be cleaning that up as follow-ups. For now this commit
handles the "hard part" of host functions by ensuring that the new
`bool` return value is plumbed in all the locations correctly.

prtest:full

* Hack around Windows MinGW miscompile (?)

* Run clang-format
alexcrichton authored Nov 26, 2024
1 parent 2b3fe80 commit 6691006
Showing 34 changed files with 428 additions and 193 deletions.
109 changes: 85 additions & 24 deletions crates/cranelift/src/compiler.rs
@@ -400,7 +400,12 @@ impl wasmtime_environ::Compiler for Compiler {
 values_vec_len,
 );

-builder.ins().return_(&[]);
+// At this time wasm functions always signal traps with longjmp or some
+// similar sort of routine, so if we got this far that means that the
+// function did not trap, so return a "true" value here to indicate that
+// to satisfy the ABI of the array-call signature.
+let true_return = builder.ins().iconst(ir::types::I8, 1);
+builder.ins().return_(&[true_return]);
 builder.finalize();

 Ok(Box::new(compiler.finish()?))
@@ -459,13 +464,14 @@ impl wasmtime_environ::Compiler for Compiler {

 // Do an indirect call to the callee.
 let callee_signature = builder.func.import_signature(array_call_sig);
-self.call_indirect_host(
+let call = self.call_indirect_host(
 &mut builder,
 callee_signature,
 callee,
 &[callee_vmctx, caller_vmctx, args_base, args_len],
 );
-
+let succeeded = builder.func.dfg.inst_results(call)[0];
+self.raise_if_host_trapped(&mut builder, caller_vmctx, succeeded);
 let results =
 self.load_values_from_array(wasm_func_ty.returns(), &mut builder, args_base, args_len);
 builder.ins().return_(&results);
@@ -637,27 +643,11 @@ impl wasmtime_environ::Compiler for Compiler {
 );
 save_last_wasm_exit_fp_and_pc(&mut builder, pointer_type, &ptr_size, limits);

-// Now it's time to delegate to the actual builtin. Builtins are stored
-// in an array in all `VMContext`s. First load the base pointer of the
-// array and then load the entry of the array that corresponds to this
-// builtin.
-let mem_flags = ir::MemFlags::trusted().with_readonly();
-let array_addr = builder.ins().load(
-pointer_type,
-mem_flags,
-vmctx,
-i32::from(ptr_size.vmcontext_builtin_functions()),
-);
-let body_offset = i32::try_from(index.index() * pointer_type.bytes()).unwrap();
-let func_addr = builder
-.ins()
-.load(pointer_type, mem_flags, array_addr, body_offset);
-
-// Forward all our own arguments to the libcall itself, and then return
-// all the same results as the libcall.
-let block_params = builder.block_params(block0).to_vec();
-let host_sig = builder.func.import_signature(host_sig);
-let call = self.call_indirect_host(&mut builder, host_sig, func_addr, &block_params);
+// Now it's time to delegate to the actual builtin. Forward all our own
+// arguments to the libcall itself, and then return all the same results
+// as the libcall.
+let args = builder.block_params(block0).to_vec();
+let call = self.call_builtin(&mut builder, vmctx, &args, index, host_sig);
 let results = builder.func.dfg.inst_results(call).to_vec();
 builder.ins().return_(&results);
 builder.finalize();
@@ -877,6 +867,77 @@ impl Compiler {
 }),
 }
 }
+
+/// Invokes the `raise` libcall in `vmctx` if the `succeeded` value
+/// indicates if a trap happened.
+///
+/// This helper is used when the host returns back to WebAssembly. The host
+/// returns a `bool` indicating whether the call succeeded. If the call
+/// failed then Cranelift needs to unwind back to the original invocation
+/// point. The unwind right now is then implemented in Wasmtime with a
+/// `longjmp`, but one day this might be implemented differently with an
+/// unwind inside of Cranelift.
+///
+/// Additionally in the future for pulley this will emit a special trap
+/// opcode for Pulley itself to cease interpretation and exit the
+/// interpreter.
+fn raise_if_host_trapped(
+&self,
+builder: &mut FunctionBuilder<'_>,
+vmctx: ir::Value,
+succeeded: ir::Value,
+) {
+let trapped_block = builder.create_block();
+let continuation_block = builder.create_block();
+builder.set_cold_block(trapped_block);
+builder
+.ins()
+.brif(succeeded, continuation_block, &[], trapped_block, &[]);
+
+builder.seal_block(trapped_block);
+builder.seal_block(continuation_block);
+
+builder.switch_to_block(trapped_block);
+let sigs = BuiltinFunctionSignatures::new(&*self.isa, &self.tunables);
+let sig = sigs.host_signature(BuiltinFunctionIndex::raise());
+self.call_builtin(builder, vmctx, &[vmctx], BuiltinFunctionIndex::raise(), sig);
+builder.ins().trap(TRAP_INTERNAL_ASSERT);
+
+builder.switch_to_block(continuation_block);
+}
+
+/// Helper to load the core `builtin` from `vmctx` and invoke it with
+/// `args`.
+fn call_builtin(
+&self,
+builder: &mut FunctionBuilder<'_>,
+vmctx: ir::Value,
+args: &[ir::Value],
+builtin: BuiltinFunctionIndex,
+sig: ir::Signature,
+) -> ir::Inst {
+let isa = &*self.isa;
+let ptr_size = isa.pointer_bytes();
+let pointer_type = isa.pointer_type();
+
+// Builtins are stored in an array in all `VMContext`s. First load the
+// base pointer of the array and then load the entry of the array that
+// corresponds to this builtin.
+let mem_flags = ir::MemFlags::trusted().with_readonly();
+let array_addr = builder.ins().load(
+pointer_type,
+mem_flags,
+vmctx,
+i32::from(ptr_size.vmcontext_builtin_functions()),
+);
+let body_offset = i32::try_from(builtin.index() * pointer_type.bytes()).unwrap();
+let func_addr = builder
+.ins()
+.load(pointer_type, mem_flags, array_addr, body_offset);
+
+let sig = builder.func.import_signature(sig);
+self.call_indirect_host(builder, sig, func_addr, args)
+}
 }

 struct FunctionCompiler<'a> {
19 changes: 14 additions & 5 deletions crates/cranelift/src/compiler/component.rs
@@ -120,6 +120,7 @@ impl<'a> TrampolineCompiler<'a> {
 let pointer_type = self.isa.pointer_type();
 let args = self.builder.func.dfg.block_params(self.block0).to_vec();
 let vmctx = args[0];
+let caller_vmctx = args[1];
 let wasm_func_ty = self.types[self.signature].unwrap_func();

 // Start off by spilling all the wasm arguments into a stack slot to be
@@ -228,6 +229,9 @@ impl<'a> TrampolineCompiler<'a> {
 host_sig.params.push(ir::AbiParam::new(pointer_type));
 callee_args.push(values_vec_len);

+// return value is a bool whether a trap was raised or not
+host_sig.returns.push(ir::AbiParam::new(ir::types::I8));
+
 // Load host function pointer from the vmcontext and then call that
 // indirect function pointer with the list of arguments.
 let host_fn = self.builder.ins().load(
@@ -237,11 +241,15 @@ impl<'a> TrampolineCompiler<'a> {
 i32::try_from(self.offsets.lowering_callee(index)).unwrap(),
 );
 let host_sig = self.builder.import_signature(host_sig);
-self.compiler
-.call_indirect_host(&mut self.builder, host_sig, host_fn, &callee_args);
+let call =
+self.compiler
+.call_indirect_host(&mut self.builder, host_sig, host_fn, &callee_args);
+let succeeded = self.builder.func.dfg.inst_results(call)[0];

 match self.abi {
 Abi::Wasm => {
+self.compiler
+.raise_if_host_trapped(&mut self.builder, caller_vmctx, succeeded);
 // After the host function has returned the results are loaded from
 // `values_vec_ptr` and then returned.
 let results = self.compiler.load_values_from_array(
@@ -253,7 +261,7 @@ impl<'a> TrampolineCompiler<'a> {
 self.builder.ins().return_(&results);
 }
 Abi::Array => {
-self.builder.ins().return_(&[]);
+self.builder.ins().return_(&[succeeded]);
 }
 }
 }
@@ -522,8 +530,8 @@ impl<'a> TrampolineCompiler<'a> {
 self.builder.seal_block(run_destructor_block);

 self.builder.switch_to_block(return_block);
-self.builder.ins().return_(&[]);
 self.builder.seal_block(return_block);
+self.abi_store_results(&[]);
 }

 /// Invokes a host libcall and returns the result.
@@ -624,7 +632,8 @@ impl<'a> TrampolineCompiler<'a> {
 ptr,
 len,
 );
-self.builder.ins().return_(&[]);
+let true_value = self.builder.ins().iconst(ir::types::I8, 1);
+self.builder.ins().return_(&[true_value]);
 }
 }
 }
4 changes: 4 additions & 0 deletions crates/cranelift/src/func_environ.rs
@@ -171,6 +171,10 @@ impl<'module_environment> FuncEnvironment<'module_environment> {
 ) -> Self {
 let builtin_functions = BuiltinFunctions::new(isa, tunables);

+// This isn't used during translation, so squash the warning about this
+// being unused from the compiler.
+let _ = BuiltinFunctions::raise;
+
 // Avoid unused warning in default build.
 #[cfg(not(feature = "wmemcheck"))]
 let _ = wmemcheck;
2 changes: 2 additions & 0 deletions crates/cranelift/src/lib.rs
@@ -151,6 +151,8 @@ fn array_call_signature(isa: &dyn TargetIsa) -> ir::Signature {
 // of `ValRaw`.
 sig.params.push(ir::AbiParam::new(isa.pointer_type()));
 sig.params.push(ir::AbiParam::new(isa.pointer_type()));
+// boolean return value of whether this function trapped
+sig.returns.push(ir::AbiParam::new(ir::types::I8));
 sig
 }

6 changes: 5 additions & 1 deletion crates/environ/src/builtin.rs
@@ -196,8 +196,12 @@ macro_rules! foreach_builtin_function {
 #[cfg(feature = "gc")]
 table_fill_gc_ref(vmctx: vmctx, table: i32, dst: i64, val: reference, len: i64);

-// Raises an unconditional trap.
+// Raises an unconditional trap with the specified code.
 trap(vmctx: vmctx, code: u8);
+
+// Raises an unconditional trap where the trap information must have
+// been previously filled in.
+raise(vmctx: vmctx);
 }
 };
 }
17 changes: 7 additions & 10 deletions crates/wasmtime/src/runtime/component/func/host.rs
@@ -48,7 +48,8 @@ impl HostFunc {
 string_encoding: StringEncoding,
 storage: *mut MaybeUninit<ValRaw>,
 storage_len: usize,
-) where
+) -> bool
+where
 F: Fn(StoreContextMut<T>, P) -> Result<R>,
 P: ComponentNamedList + Lift + 'static,
 R: ComponentNamedList + Lower + 'static,
@@ -295,24 +296,19 @@ unsafe fn call_host_and_handle_result<T>(
 &Arc<ComponentTypes>,
 StoreContextMut<'_, T>,
 ) -> Result<()>,
-) {
+) -> bool {
 let cx = VMComponentContext::from_opaque(cx);
 let instance = (*cx).instance();
 let types = (*instance).component_types();
 let raw_store = (*instance).store();
 let mut store = StoreContextMut(&mut *raw_store.cast());

-let res = crate::runtime::vm::catch_unwind_and_longjmp(|| {
+crate::runtime::vm::catch_unwind_and_record_trap(|| {
 store.0.call_hook(CallHook::CallingHost)?;
 let res = func(instance, types, store.as_context_mut());
 store.0.call_hook(CallHook::ReturningFromHost)?;
 res
-});
-
-match res {
-Ok(()) => {}
-Err(e) => crate::runtime::vm::raise_user_trap(e),
-}
+})
 }

 unsafe fn call_host_dynamic<T, F>(
@@ -435,7 +431,8 @@ extern "C" fn dynamic_entrypoint<T, F>(
 string_encoding: StringEncoding,
 storage: *mut MaybeUninit<ValRaw>,
 storage_len: usize,
-) where
+) -> bool
+where
 F: Fn(StoreContextMut<'_, T>, &[Val], &mut [Val]) -> Result<()> + Send + Sync + 'static,
 {
 let data = data as *const F;
14 changes: 5 additions & 9 deletions crates/wasmtime/src/runtime/func.rs
@@ -1592,7 +1592,7 @@ impl Func {
 /// can pass to the called wasm function, if desired.
 pub(crate) fn invoke_wasm_and_catch_traps<T>(
 store: &mut StoreContextMut<'_, T>,
-closure: impl FnMut(*mut VMContext),
+closure: impl FnMut(*mut VMContext) -> bool,
 ) -> Result<()> {
 unsafe {
 let exit = enter_wasm(store);
@@ -2296,7 +2296,8 @@ impl HostContext {
 caller_vmctx: *mut VMOpaqueContext,
 args: *mut ValRaw,
 args_len: usize,
-) where
+) -> bool
+where
 F: Fn(Caller<'_, T>, P) -> R + 'static,
 P: WasmTyList,
 R: WasmRet,
@@ -2356,15 +2357,10 @@ impl HostContext {

 // With nothing else on the stack move `run` into this
 // closure and then run it as part of `Caller::with`.
-let result = crate::runtime::vm::catch_unwind_and_longjmp(move || {
+crate::runtime::vm::catch_unwind_and_record_trap(move || {
 let caller_vmctx = VMContext::from_opaque(caller_vmctx);
 Caller::with(caller_vmctx, run)
-});
-
-match result {
-Ok(val) => val,
-Err(err) => crate::runtime::vm::raise_user_trap(err),
-}
+})
 }
 }

2 changes: 1 addition & 1 deletion crates/wasmtime/src/runtime/func/typed.rs
@@ -219,7 +219,7 @@ where
 let storage = core::ptr::slice_from_raw_parts_mut(storage, storage_len);
 func_ref
 .as_ref()
-.array_call(VMOpaqueContext::from_vmcontext(caller), storage);
+.array_call(VMOpaqueContext::from_vmcontext(caller), storage)
 });

 let (_, storage) = captures;
33 changes: 6 additions & 27 deletions crates/wasmtime/src/runtime/trampoline/func.rs
@@ -25,23 +25,13 @@ unsafe extern "C" fn array_call_shim<F>(
 caller_vmctx: *mut VMOpaqueContext,
 values_vec: *mut ValRaw,
 values_vec_len: usize,
-) where
+) -> bool
+where
 F: Fn(*mut VMContext, &mut [ValRaw]) -> Result<()> + 'static,
 {
-// Here we are careful to use `catch_unwind` to ensure Rust panics don't
-// unwind past us. The primary reason for this is that Rust considers it UB
-// to unwind past an `extern "C"` function. Here we are in an `extern "C"`
-// function and the cross into wasm was through an `extern "C"` function at
-// the base of the stack as well. We'll need to wait for assorted RFCs and
-// language features to enable this to be done in a sound and stable fashion
-// before avoiding catching the panic here.
-//
-// Also note that there are intentionally no local variables on this stack
-// frame. The reason for that is that some of the "raise" functions we have
-// below will trigger a longjmp, which won't run local destructors if we
-// have any. To prevent leaks we avoid having any local destructors by
-// avoiding local variables.
-let result = crate::runtime::vm::catch_unwind_and_longjmp(|| {
+// Be sure to catch Rust panics to manually shepherd them across the wasm
+// boundary, and then otherwise delegate as normal.
+crate::runtime::vm::catch_unwind_and_record_trap(|| {
 let vmctx = VMArrayCallHostFuncContext::from_opaque(vmctx);
 // Double-check ourselves in debug mode, but we control
 // the `Any` here so an unsafe downcast should also
@@ -51,18 +41,7 @@ unsafe extern "C" fn array_call_shim<F>(
 let state = &*(state as *const _ as *const TrampolineState<F>);
 let values_vec = core::slice::from_raw_parts_mut(values_vec, values_vec_len);
 (state.func)(VMContext::from_opaque(caller_vmctx), values_vec)
-});
-
-match result {
-Ok(()) => {}
-
-// If a trap was raised (an error returned from the imported function)
-// then we smuggle the trap through `Box<dyn Error>` through to the
-// call-site, which gets unwrapped in `Trap::from_runtime` later on as we
-// convert from the internal `Trap` type to our own `Trap` type in this
-// crate.
-Err(err) => crate::runtime::vm::raise_user_trap(err),
-}
+})
 }

 pub fn create_array_call_function<F>(