@@ -3,8 +3,10 @@ package host
33import (
44 "bytes"
55 "context"
6+ "crypto/sha256"
67 "encoding/base64"
78 "encoding/binary"
9+ "encoding/hex"
810 "encoding/json"
911 "errors"
1012 "fmt"
@@ -119,6 +121,14 @@ type Module struct {
119121 stopCh chan struct {}
120122}
121123
// WasmBinaryStore is the storage dependency of NewModule. It supplies the wasm binary for a
// workflow and keeps a serialised form of the module built from that binary so later loads can
// deserialise it instead of rebuilding from the binary.
type WasmBinaryStore interface {
	// GetWasmBinary returns the wasm binary for the given workflowID.
	GetWasmBinary(ctx context.Context, workflowID string) ([]byte, error)

	// StoreSerialisedModule saves serialisedModule for the given workflowID. binaryID identifies
	// the wasm binary that the serialised module was built from.
	StoreSerialisedModule(workflowID string, binaryID string, serialisedModule []byte) error

	// GetSerialisedModulePath returns the path to the serialised module for the given
	// workflowID. If the module does not exist, exists will be false.
	GetSerialisedModulePath(workflowID string) (path string, exists bool, err error)
}
131+
122132// WithDeterminism sets the Determinism field to a deterministic seed from a known time.
123133//
124134// "The Times 03/Jan/2009 Chancellor on brink of second bailout for banks"
@@ -133,7 +143,10 @@ func WithDeterminism() func(*ModuleConfig) {
133143 }
134144}
135145
136- func NewModule (modCfg * ModuleConfig , binary []byte , opts ... func (* ModuleConfig )) (* Module , error ) {
146+ // NewModule uses the WasmBinaryStore to load the module for a given workflowID. If the module is available as a serialised
147+ // representation from the WasmBinaryStore that will be used, else it will be loaded from the wasm binary.
148+ func NewModule (ctx context.Context , lggr logger.Logger , modCfg * ModuleConfig , workflowID string , wasmStore WasmBinaryStore ,
149+ opts ... func (* ModuleConfig )) (* Module , error ) {
137150 // Apply options to the module config.
138151 for _ , opt := range opts {
139152 opt (modCfg )
@@ -193,37 +206,33 @@ func NewModule(modCfg *ModuleConfig, binary []byte, opts ...func(*ModuleConfig))
193206
194207 cfg .CacheConfigLoadDefault ()
195208 cfg .SetCraneliftOptLevel (wasmtime .OptLevelSpeedAndSize )
196-
209+
197210 // Load testing shows that leaving native unwind info enabled causes a very large slowdown when loading multiple modules.
198211 cfg .SetNativeUnwindInfo (false )
199212
200213 engine := wasmtime .NewEngineWithConfig (cfg )
201- if ! modCfg .IsUncompressed {
202- // validate the binary size before decompressing
203- // this is to prevent decompression bombs
204- if uint64 (len (binary )) > modCfg .MaxCompressedBinarySize {
205- return nil , fmt .Errorf ("compressed binary size exceeds the maximum allowed size of %d bytes" , modCfg .MaxCompressedBinarySize )
206- }
207214
208- rdr := io .LimitReader (brotli .NewReader (bytes .NewBuffer (binary )), int64 (modCfg .MaxDecompressedBinarySize + 1 ))
209- decompedBinary , err := io .ReadAll (rdr )
210- if err != nil {
211- return nil , fmt .Errorf ("failed to decompress binary: %w" , err )
212- }
213-
214- binary = decompedBinary
215+ var mod * wasmtime.Module
216+ serialisedModulePath , exists , err := wasmStore .GetSerialisedModulePath (workflowID )
217+ if err != nil {
218+ return nil , fmt .Errorf ("error getting serialised module: %w" , err )
215219 }
216220
217- // Validate the decompressed binary size.
218- // io.LimitReader prevents decompression bombs by reading up to a set limit, but it will not return an error if the limit is reached.
219- // The Read() method will return io.EOF, and ReadAll will gracefully handle it and return nil.
220- if uint64 (len (binary )) > modCfg .MaxDecompressedBinarySize {
221- return nil , fmt .Errorf ("decompressed binary size reached the maximum allowed size of %d bytes" , modCfg .MaxDecompressedBinarySize )
221+ if exists {
222+ mod , err = wasmtime .NewModuleDeserializeFile (engine , serialisedModulePath )
223+ if err != nil {
224+ // It's possible that an error occurred because the module was serialised with a different engine configuration or
225+ // wasmtime version so the error is ignored and the code falls back to loading it from the wasm binary.
226+ lggr .Debugw ("error deserializing module, attempting to load from binary" , "workflowID" , workflowID , "error" , err )
227+ }
222228 }
223229
224- mod , err := wasmtime .NewModule (engine , binary )
225- if err != nil {
226- return nil , fmt .Errorf ("error creating wasmtime module: %w" , err )
230+ // If the serialized module was not found or deserialization failed, load the module from the wasm binary.
231+ if mod == nil {
232+ mod , err = loadModuleFromWasmBinary (ctx , lggr , modCfg , workflowID , wasmStore , engine )
233+ if err != nil {
234+ return nil , fmt .Errorf ("error loading module from wasm binary: %w" , err )
235+ }
227236 }
228237
229238 linker , err := newWasiLinker (modCfg , engine )
@@ -287,6 +296,70 @@ func NewModule(modCfg *ModuleConfig, binary []byte, opts ...func(*ModuleConfig))
287296 return m , nil
288297}
289298
299+ func loadModuleFromWasmBinary (ctx context.Context , lggr logger.Logger , modCfg * ModuleConfig , workflowID string , wasmStore WasmBinaryStore , engine * wasmtime.Engine ) (* wasmtime.Module , error ) {
300+ // Loading from the module binary is relatively very slow (~100 times slower than deserialization) so log the
301+ // time here to make it obvious when this is happening as it will impact workflow startup time.
302+ wasmBinary , err := wasmStore .GetWasmBinary (ctx , workflowID )
303+ if err != nil {
304+ return nil , fmt .Errorf ("error getting workflow binary: %w" , err )
305+ }
306+
307+ hash := sha256 .Sum256 (wasmBinary )
308+ binaryID := hex .EncodeToString (hash [:])
309+
310+ lggr .Infow ("loading module from binary" , "workflowID" , workflowID )
311+ mod , err := newModuleFromBinary (wasmBinary , modCfg , engine )
312+ if err != nil {
313+ return nil , fmt .Errorf ("error creating new module from wasm binary: %w" , err )
314+ }
315+ lggr .Infow ("finished loading module from binary" , "workflowID" , workflowID )
316+
317+ // Store the serialised module for future use.
318+ serialisedMod , err := mod .Serialize ()
319+ if err != nil {
320+ return nil , fmt .Errorf ("error serialising module: %w" , err )
321+ }
322+
323+ err = wasmStore .StoreSerialisedModule (workflowID , binaryID , serialisedMod )
324+ if err != nil {
325+ return nil , fmt .Errorf ("error storing serialised module: %w" , err )
326+ }
327+ return mod , nil
328+ }
329+
330+ func newModuleFromBinary (wasmBinary []byte , modCfg * ModuleConfig , engine * wasmtime.Engine ) (* wasmtime.Module , error ) {
331+
332+ if ! modCfg .IsUncompressed {
333+ // validate the binary size before decompressing
334+ // this is to prevent decompression bombs
335+ if uint64 (len (wasmBinary )) > modCfg .MaxCompressedBinarySize {
336+ return nil , fmt .Errorf ("compressed binary size exceeds the maximum allowed size of %d bytes" , modCfg .MaxCompressedBinarySize )
337+ }
338+
339+ rdr := io .LimitReader (brotli .NewReader (bytes .NewBuffer (wasmBinary )), int64 (modCfg .MaxDecompressedBinarySize + 1 ))
340+ decompedBinary , err := io .ReadAll (rdr )
341+ if err != nil {
342+ return nil , fmt .Errorf ("failed to decompress binary: %w" , err )
343+ }
344+
345+ wasmBinary = decompedBinary
346+ }
347+
348+ // Validate the decompressed binary size.
349+ // io.LimitReader prevents decompression bombs by reading up to a set limit, but it will not return an error if the limit is reached.
350+ // The Read() method will return io.EOF, and ReadAll will gracefully handle it and return nil.
351+ if uint64 (len (wasmBinary )) > modCfg .MaxDecompressedBinarySize {
352+ return nil , fmt .Errorf ("decompressed binary size reached the maximum allowed size of %d bytes" , modCfg .MaxDecompressedBinarySize )
353+ }
354+
355+ mod , err := wasmtime .NewModule (engine , wasmBinary )
356+ if err != nil {
357+ return nil , fmt .Errorf ("error creating wasmtime module: %w" , err )
358+ }
359+
360+ return mod , nil
361+ }
362+
290363func (m * Module ) Start () {
291364 m .wg .Add (1 )
292365 go func () {
0 commit comments