Skip to content

Commit

Permalink
improve threads termination
Browse files Browse the repository at this point in the history
  • Loading branch information
toyobayashi committed May 11, 2024
1 parent 0fa53a7 commit 0de2556
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 19 deletions.
1 change: 1 addition & 0 deletions packages/core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ export type {
export type {
LoadOptions,
InstantiateOptions,
LoadedSource,
InstantiatedSource
} from './load'

Expand Down
59 changes: 46 additions & 13 deletions packages/core/src/load.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@ import { createNapiModule } from './emnapi/index'
import type { CreateOptions, NapiModule } from './emnapi/index'

/** @public */
export interface InstantiatedSource extends WebAssembly.WebAssemblyInstantiatedSource {
export interface LoadedSource extends WebAssembly.WebAssemblyInstantiatedSource {
usedInstance: WebAssembly.Instance
}

/** @public */
export interface InstantiatedSource extends LoadedSource {
napiModule: NapiModule
}

Expand Down Expand Up @@ -70,13 +75,6 @@ function loadNapiModuleImpl (loadFn: Function, userNapiModule: NapiModule | unde
}

if (wasi) {
Object.assign(
importObject,
typeof wasi.getImportObject === 'function'
? wasi.getImportObject()
: { wasi_snapshot_preview1: wasi.wasiImport }
)

wasiThreads = new WASIThreads(
napiModule.childThread
? {
Expand All @@ -88,6 +86,15 @@ function loadNapiModuleImpl (loadFn: Function, userNapiModule: NapiModule | unde
waitThreadStart: napiModule.waitThreadStart
}
)
wasiThreads.patchWasiInstance(wasi)

Object.assign(
importObject,
typeof wasi.getImportObject === 'function'
? wasi.getImportObject()
: { wasi_snapshot_preview1: wasi.wasiImport }
)

Object.assign(importObject, wasiThreads.getImportObject())
}

Expand Down Expand Up @@ -127,13 +134,20 @@ function loadNapiModuleImpl (loadFn: Function, userNapiModule: NapiModule | unde
instance = { exports }
}
const module = source.module

const isCommand = ('_start' in originalExports) && (typeof originalExports._start === 'function')

if (wasi) {
if (napiModule.childThread) {
instance = createInstanceProxy(instance, memory)
}
wasiThreads!.setup(instance, module, memory)
if ('_start' in originalExports) {
wasi.start(instance)
if (isCommand) {
if (napiModule.childThread) {
wasi.start(instance)
} else {
setupInstance(wasi, instance)
}
} else {
wasi.initialize(instance)
}
Expand All @@ -155,7 +169,11 @@ function loadNapiModuleImpl (loadFn: Function, userNapiModule: NapiModule | unde
table
})

const ret: any = { instance: originalInstance, module }
const ret: any = {
instance: originalInstance,
module,
usedInstance: instance
}
if (!isLoad) {
ret.napiModule = napiModule
}
Expand Down Expand Up @@ -192,7 +210,7 @@ export function loadNapiModule (
/** Only support `BufferSource` or `WebAssembly.Module` on Node.js */
wasmInput: InputType | Promise<InputType>,
options?: LoadOptions
): Promise<WebAssembly.WebAssemblyInstantiatedSource> {
): Promise<LoadedSource> {
if (typeof napiModule !== 'object' || napiModule === null) {
throw new TypeError('Invalid napiModule')
}
Expand All @@ -204,7 +222,7 @@ export function loadNapiModuleSync (
napiModule: NapiModule,
wasmInput: BufferSource | WebAssembly.Module,
options?: LoadOptions
): WebAssembly.WebAssemblyInstantiatedSource {
): LoadedSource {
if (typeof napiModule !== 'object' || napiModule === null) {
throw new TypeError('Invalid napiModule')
}
Expand All @@ -227,3 +245,18 @@ export function instantiateNapiModuleSync (
): InstantiatedSource {
return loadNapiModuleImpl(loadSyncCallback, undefined, wasmInput, options)
}

function setupInstance (wasi: WASIInstance, instance: WebAssembly.Instance): void {
const symbols = Object.getOwnPropertySymbols(wasi)
const selectDescription = (description: string) => (s: symbol) => {
if (s.description) {
return s.description === description
}
return s.toString() === `Symbol(${description})`
}
const kInstance = symbols.filter(selectDescription('kInstance'))[0]
const kSetMemory = symbols.filter(selectDescription('kSetMemory'))[0];

(wasi as any)[kInstance] = instance;
(wasi as any)[kSetMemory](instance.exports.memory)
}
2 changes: 2 additions & 0 deletions packages/wasi-threads/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@ export { ThreadMessageHandler } from './worker'
export type { InstantiatePayload, ThreadMessageHandlerOptions } from './worker'

export { createInstanceProxy } from './proxy'

export { isTrapError } from './util'
7 changes: 7 additions & 0 deletions packages/wasi-threads/src/thread-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,20 @@ export class ThreadManager {
if (worker.whenLoaded) return worker.whenLoaded
const err = this.printErr
const beforeLoad = this._beforeLoad
// eslint-disable-next-line @typescript-eslint/no-this-alias
const _this = this
worker.whenLoaded = new Promise<WorkerLike>((resolve, reject) => {
const handleError = function (e: { message: string }): void {
let message = 'worker sent an error!'
if (worker.__emnapi_tid !== undefined) {
message = 'worker (tid = ' + worker.__emnapi_tid + ') sent an error!'
}
err(message + ' ' + e.message)
if (e.message === 'unreachable') {
try {
_this.terminateAllThreads()
} catch (_) {}
}
reject(e)
throw e as Error
}
Expand Down
5 changes: 5 additions & 0 deletions packages/wasi-threads/src/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,8 @@ export function deserizeErrorFromBuffer (sab: SharedArrayBuffer): Error | null {
})
return error
}

/** @public */
export function isTrapError (e: Error): e is WebAssembly.RuntimeError {
return (e instanceof WebAssembly.RuntimeError) && (e.message === 'unreachable')
}
21 changes: 17 additions & 4 deletions packages/wasi-threads/src/wasi-threads.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import { ENVIRONMENT_IS_NODE, deserizeErrorFromBuffer, getPostMessage } from './util'
import { ENVIRONMENT_IS_NODE, deserizeErrorFromBuffer, getPostMessage, isTrapError } from './util'
import { checkSharedWasmMemory, ThreadManager } from './thread-manager'
import type { WorkerMessageEvent, ThreadManagerOptions } from './thread-manager'

/** @public */
export interface WASIInstance {
readonly wasiImport?: Record<string, any>
initialize (instance: object): void
start (instance: object): void
start (instance: object): number
getImportObject? (): any
}

Expand Down Expand Up @@ -244,16 +244,29 @@ export class WASIThreads {

public patchWasiInstance<T extends WASIInstance> (wasi: T): T {
if (!wasi) return wasi
// eslint-disable-next-line @typescript-eslint/no-this-alias
const _this = this
const wasiImport = wasi.wasiImport
if (wasiImport) {
const proc_exit = wasiImport.proc_exit
// eslint-disable-next-line @typescript-eslint/no-this-alias
const _this = this
wasiImport.proc_exit = function (code: number): number {
_this.terminateAllThreads()
return proc_exit.call(this, code)
}
}
const start = wasi.start
if (typeof start === 'function') {
wasi.start = function (instance: object): number {
try {
return start.call(this, instance)
} catch (err) {
if (isTrapError(err)) {
_this.terminateAllThreads()
}
throw err
}
}
}
return wasi
}

Expand Down
15 changes: 13 additions & 2 deletions packages/wasi-threads/src/worker.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import type { WorkerMessageEvent } from './thread-manager'
import { getPostMessage, serizeErrorToBuffer } from './util'
import { getPostMessage, isTrapError, serizeErrorToBuffer } from './util'

/** @public */
export interface InstantiatePayload {
Expand Down Expand Up @@ -95,7 +95,18 @@ export class ThreadMessageHandler {
const tid = payload.tid
const startArg = payload.arg
notifyPthreadCreateResult(payload.sab, 1)
;(this.instance!.exports.wasi_thread_start as Function)(tid, startArg)
try {
(this.instance!.exports.wasi_thread_start as Function)(tid, startArg)
} catch (err) {
if (isTrapError(err)) {
postMessage({
__emnapi__: {
type: 'terminate-all-threads'
}
})
}
throw err
}
postMessage({
__emnapi__: {
type: 'cleanup-thread',
Expand Down

0 comments on commit 0de2556

Please sign in to comment.