diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 0a62635d..4996d4f5 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -24,6 +24,7 @@ export type { export type { LoadOptions, InstantiateOptions, + LoadedSource, InstantiatedSource } from './load' diff --git a/packages/core/src/load.ts b/packages/core/src/load.ts index 6c7947b9..bc73296d 100644 --- a/packages/core/src/load.ts +++ b/packages/core/src/load.ts @@ -4,7 +4,12 @@ import { createNapiModule } from './emnapi/index' import type { CreateOptions, NapiModule } from './emnapi/index' /** @public */ -export interface InstantiatedSource extends WebAssembly.WebAssemblyInstantiatedSource { +export interface LoadedSource extends WebAssembly.WebAssemblyInstantiatedSource { + usedInstance: WebAssembly.Instance +} + +/** @public */ +export interface InstantiatedSource extends LoadedSource { napiModule: NapiModule } @@ -70,13 +75,6 @@ function loadNapiModuleImpl (loadFn: Function, userNapiModule: NapiModule | unde } if (wasi) { - Object.assign( - importObject, - typeof wasi.getImportObject === 'function' - ? wasi.getImportObject() - : { wasi_snapshot_preview1: wasi.wasiImport } - ) - wasiThreads = new WASIThreads( napiModule.childThread ? { @@ -88,6 +86,15 @@ function loadNapiModuleImpl (loadFn: Function, userNapiModule: NapiModule | unde waitThreadStart: napiModule.waitThreadStart } ) + wasiThreads.patchWasiInstance(wasi) + + Object.assign( + importObject, + typeof wasi.getImportObject === 'function' + ? wasi.getImportObject() + : { wasi_snapshot_preview1: wasi.wasiImport } + ) + Object.assign(importObject, wasiThreads.getImportObject()) } @@ -127,13 +134,20 @@ function loadNapiModuleImpl (loadFn: Function, userNapiModule: NapiModule | unde instance = { exports } } const module = source.module + + const isCommand = ('_start' in originalExports) && (typeof originalExports._start === 'function') + if (wasi) { if (napiModule.childThread) { instance = createInstanceProxy(instance, memory) } wasiThreads!.setup(instance, module, memory) - if ('_start' in originalExports) { - wasi.start(instance) + if (isCommand) { + if (napiModule.childThread) { + wasi.start(instance) + } else { + setupInstance(wasi, instance) + } } else { wasi.initialize(instance) } @@ -155,7 +169,11 @@ function loadNapiModuleImpl (loadFn: Function, userNapiModule: NapiModule | unde table }) - const ret: any = { instance: originalInstance, module } + const ret: any = { + instance: originalInstance, + module, + usedInstance: instance + } if (!isLoad) { ret.napiModule = napiModule } @@ -192,7 +210,7 @@ export function loadNapiModule ( /** Only support `BufferSource` or `WebAssembly.Module` on Node.js */ wasmInput: InputType | Promise, options?: LoadOptions -): Promise { +): Promise { if (typeof napiModule !== 'object' || napiModule === null) { throw new TypeError('Invalid napiModule') } @@ -204,7 +222,7 @@ export function loadNapiModuleSync ( napiModule: NapiModule, wasmInput: BufferSource | WebAssembly.Module, options?: LoadOptions -): WebAssembly.WebAssemblyInstantiatedSource { +): LoadedSource { if (typeof napiModule !== 'object' || napiModule === null) { throw new TypeError('Invalid napiModule') } @@ -227,3 +245,18 @@ export function instantiateNapiModuleSync ( ): InstantiatedSource { return loadNapiModuleImpl(loadSyncCallback, undefined, wasmInput, options) } + +function setupInstance (wasi: WASIInstance, instance: WebAssembly.Instance): void { + const symbols = Object.getOwnPropertySymbols(wasi) + const selectDescription = (description: string) => (s: symbol) => { + if (s.description) { + return s.description === description + } + return s.toString() === `Symbol(${description})` + } + const kInstance = symbols.filter(selectDescription('kInstance'))[0] + const kSetMemory = symbols.filter(selectDescription('kSetMemory'))[0]; + + (wasi as any)[kInstance] = instance; + (wasi as any)[kSetMemory](instance.exports.memory) +} diff --git a/packages/wasi-threads/src/index.ts b/packages/wasi-threads/src/index.ts index 99051e75..8301c834 100644 --- a/packages/wasi-threads/src/index.ts +++ b/packages/wasi-threads/src/index.ts @@ -18,3 +18,5 @@ export { ThreadMessageHandler } from './worker' export type { InstantiatePayload, ThreadMessageHandlerOptions } from './worker' export { createInstanceProxy } from './proxy' + +export { isTrapError } from './util' diff --git a/packages/wasi-threads/src/thread-manager.ts b/packages/wasi-threads/src/thread-manager.ts index 7c45e581..f8a459e3 100644 --- a/packages/wasi-threads/src/thread-manager.ts +++ b/packages/wasi-threads/src/thread-manager.ts @@ -99,6 +99,8 @@ export class ThreadManager { if (worker.whenLoaded) return worker.whenLoaded const err = this.printErr const beforeLoad = this._beforeLoad + // eslint-disable-next-line @typescript-eslint/no-this-alias + const _this = this worker.whenLoaded = new Promise((resolve, reject) => { const handleError = function (e: { message: string }): void { let message = 'worker sent an error!' @@ -106,6 +108,11 @@ export class ThreadManager { message = 'worker (tid = ' + worker.__emnapi_tid + ') sent an error!' } err(message + ' ' + e.message) + if (e.message === 'unreachable') { + try { + _this.terminateAllThreads() + } catch (_) {} + } reject(e) throw e as Error } diff --git a/packages/wasi-threads/src/util.ts b/packages/wasi-threads/src/util.ts index dfd1474f..21f4b5f3 100644 --- a/packages/wasi-threads/src/util.ts +++ b/packages/wasi-threads/src/util.ts @@ -57,3 +57,8 @@ export function deserizeErrorFromBuffer (sab: SharedArrayBuffer): Error | null { }) return error } + +/** @public */ +export function isTrapError (e: Error): e is WebAssembly.RuntimeError { + return (e instanceof WebAssembly.RuntimeError) && (e.message === 'unreachable') +} diff --git a/packages/wasi-threads/src/wasi-threads.ts b/packages/wasi-threads/src/wasi-threads.ts index 485ad377..45460cb8 100644 --- a/packages/wasi-threads/src/wasi-threads.ts +++ b/packages/wasi-threads/src/wasi-threads.ts @@ -1,4 +1,4 @@ -import { ENVIRONMENT_IS_NODE, deserizeErrorFromBuffer, getPostMessage } from './util' +import { ENVIRONMENT_IS_NODE, deserizeErrorFromBuffer, getPostMessage, isTrapError } from './util' import { checkSharedWasmMemory, ThreadManager } from './thread-manager' import type { WorkerMessageEvent, ThreadManagerOptions } from './thread-manager' @@ -6,7 +6,7 @@ import type { WorkerMessageEvent, ThreadManagerOptions } from './thread-manager' export interface WASIInstance { readonly wasiImport?: Record initialize (instance: object): void - start (instance: object): void + start (instance: object): number getImportObject? (): any } @@ -244,16 +244,29 @@ export class WASIThreads { public patchWasiInstance (wasi: T): T { if (!wasi) return wasi + // eslint-disable-next-line @typescript-eslint/no-this-alias + const _this = this const wasiImport = wasi.wasiImport if (wasiImport) { const proc_exit = wasiImport.proc_exit - // eslint-disable-next-line @typescript-eslint/no-this-alias - const _this = this wasiImport.proc_exit = function (code: number): number { _this.terminateAllThreads() return proc_exit.call(this, code) } } + const start = wasi.start + if (typeof start === 'function') { + wasi.start = function (instance: object): number { + try { + return start.call(this, instance) + } catch (err) { + if (isTrapError(err)) { + _this.terminateAllThreads() + } + throw err + } + } + } return wasi } diff --git a/packages/wasi-threads/src/worker.ts b/packages/wasi-threads/src/worker.ts index 9d8b545a..181bb041 100644 --- a/packages/wasi-threads/src/worker.ts +++ b/packages/wasi-threads/src/worker.ts @@ -1,5 +1,5 @@ import type { WorkerMessageEvent } from './thread-manager' -import { getPostMessage, serizeErrorToBuffer } from './util' +import { getPostMessage, isTrapError, serizeErrorToBuffer } from './util' /** @public */ export interface InstantiatePayload { @@ -95,7 +95,18 @@ export class ThreadMessageHandler { const tid = payload.tid const startArg = payload.arg notifyPthreadCreateResult(payload.sab, 1) - ;(this.instance!.exports.wasi_thread_start as Function)(tid, startArg) + try { + (this.instance!.exports.wasi_thread_start as Function)(tid, startArg) + } catch (err) { + if (isTrapError(err)) { + postMessage({ + __emnapi__: { + type: 'terminate-all-threads' + } + }) + } + throw err + } postMessage({ __emnapi__: { type: 'cleanup-thread',