Skip to content

Commit ed09e7a

Browse files
authored
wasm: Embed rule names in a custom section (#510)
1 parent c8b6915 commit ed09e7a

File tree

3 files changed

+121
-23
lines changed

3 files changed

+121
-23
lines changed

packages/wasm/src/index.js

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,13 @@ class Compiler {
469469

470470
// The rule ID is a 0-based index that's mapped to the name.
471471
// It is *not* the same as the function index the rule's eval function.
472-
this.ruleIdByName = new Map(Object.keys(grammar.rules).map((name, i) => [name, i]));
472+
// Ensure that the default start rule always has id 0.
473+
this.ruleIdByName = new Map([[grammar.defaultStartRule, 0]]);
474+
for (const name of Object.keys(grammar.rules)) {
475+
if (name !== grammar.defaultStartRule) {
476+
this.ruleIdByName.set(name, this.ruleIdByName.size);
477+
}
478+
}
473479
}
474480

475481
ruleBody(ruleName, grammar = this.grammar) {
@@ -502,10 +508,6 @@ class Compiler {
502508
return w.funcidx(checkNotNull(this.ruleIdByName.get(name)) + offset);
503509
}
504510

505-
ruleNames() {
506-
return [...this.ruleIdByName.keys()];
507-
}
508-
509511
// Return an object implementing all of the debug imports.
510512
getDebugImports(log) {
511513
const ans = {};
@@ -576,10 +578,22 @@ class Compiler {
576578
return this.asm._functionDecls.at(-1);
577579
}
578580

581+
buildRuleNamesSection(ruleNames) {
582+
// A custom section that allows the clients to look up rule IDs by name.
583+
// They're simply encoded as a vec(name), and the client can turn this
584+
// into a list/array and use the ruleId as the index.
585+
return w.custom(w.name('ruleNames'), w.vec(ruleNames.map((n, i) => w.name(n))));
586+
}
587+
579588
buildModule(typeMap, functionDecls) {
580589
const {importDecls} = this;
581590
assert(this.importDecls.length === prebuilt.destImportCount, 'import count mismatch');
582591

592+
const ruleNames = [...this.ruleIdByName.keys()];
593+
594+
// Ensure that `ruleNames` is in the correct order.
595+
ruleNames.forEach((n, i) => assert(i === this.ruleIdByName.get(n)));
596+
583597
typeMap.addDecls(importDecls);
584598
typeMap.addDecls(functionDecls);
585599

@@ -613,7 +627,7 @@ class Compiler {
613627
const table = w.table(
614628
w.tabletype(w.elemtype.funcref, w.limits.minmax(numRules, numRules)),
615629
);
616-
const tableData = this.ruleNames().map(name => this.ruleEvalFuncIdx(name));
630+
const tableData = ruleNames.map(name => this.ruleEvalFuncIdx(name));
617631
assert(numRules === tableData.length, 'Invalid rule count');
618632

619633
// Determine the index of the start function.
@@ -633,6 +647,7 @@ class Compiler {
633647
w.startsec(w.start(startFuncidx)),
634648
w.elemsec([w.elem(w.tableidx(0), [instr.i32.const, w.i32(0), instr.end], tableData)]),
635649
mergeSections(w.SECTION_ID_CODE, prebuilt.codesec, codes),
650+
w.customsec(this.buildRuleNamesSection(ruleNames)),
636651
]);
637652
const bytes = Uint8Array.from(mod.flat(Infinity));
638653

packages/wasm/test/data/_es5.wasm

2.74 KB
Binary file not shown.

packages/wasm/test/go/matcher.go

Lines changed: 100 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ package main
22

33
import (
44
"context"
5+
"encoding/binary"
56
"fmt"
7+
"io"
68
"os"
79

810
"github.com/tetratelabs/wazero"
@@ -11,7 +13,7 @@ import (
1113

1214
// Constants for memory layout
1315
const (
14-
wasmPageSize = 64 * 1024
16+
wasmPageSize = 64 * 1024
1517
InputBufferOffset = wasmPageSize
1618
InputBufferSize = wasmPageSize
1719
MemoTableOffset = InputBufferOffset + InputBufferSize
@@ -35,15 +37,72 @@ func (m *WasmMatcher) GetModule() api.Module {
3537
}
3638

3739
func NewWasmMatcher(ctx context.Context) *WasmMatcher {
40+
// Create a new runtime with custom sections enabled
41+
config := wazero.NewRuntimeConfig().WithCustomSections(true)
42+
3843
return &WasmMatcher{
39-
runtime: wazero.NewRuntime(ctx),
44+
runtime: wazero.NewRuntimeWithConfig(ctx, config),
4045
ctx: ctx,
4146
ruleIds: make(map[string]int),
4247
pos: 0,
4348
lastMatchResult: false,
4449
}
4550
}
4651

52+
// parseRuleNames parses the rule names from the custom section data
53+
// The data is formatted as a WebAssembly vector of strings (each string is a length-prefixed UTF-8 bytes)
54+
// with LEB128-encoded lengths
55+
func parseRuleNames(data []byte) ([]string, error) {
56+
if len(data) == 0 {
57+
return nil, fmt.Errorf("empty custom section data")
58+
}
59+
60+
// Read the number of names (vec length) as LEB128-encoded uint32
61+
numNamesUint64, bytesRead := binary.Uvarint(data)
62+
if bytesRead <= 0 {
63+
return nil, fmt.Errorf("failed to read number of names: %v", io.ErrUnexpectedEOF)
64+
}
65+
66+
// Ensure the value fits in uint32
67+
if numNamesUint64 > uint64(^uint32(0)) {
68+
return nil, fmt.Errorf("number of names exceeds maximum uint32 value")
69+
}
70+
71+
numNames := uint32(numNamesUint64)
72+
data = data[bytesRead:]
73+
74+
names := make([]string, numNames)
75+
for i := uint32(0); i < numNames; i++ {
76+
// Read the length of the name as LEB128-encoded uint32
77+
nameLenUint64, bytesRead := binary.Uvarint(data)
78+
if bytesRead <= 0 {
79+
return nil, fmt.Errorf("failed to read name length: %v", io.ErrUnexpectedEOF)
80+
}
81+
82+
// Ensure the value fits in uint32
83+
if nameLenUint64 > uint64(^uint32(0)) {
84+
return nil, fmt.Errorf("name length exceeds maximum uint32 value")
85+
}
86+
87+
nameLen := uint32(nameLenUint64)
88+
data = data[bytesRead:]
89+
90+
// Ensure we have enough bytes to read
91+
if uint64(len(data)) < uint64(nameLen) {
92+
return nil, fmt.Errorf("buffer too small to read name bytes")
93+
}
94+
95+
// Read the name bytes
96+
nameBytes := data[:nameLen]
97+
data = data[nameLen:]
98+
99+
// Convert to string
100+
names[i] = string(nameBytes)
101+
}
102+
103+
return names, nil
104+
}
105+
47106
func (m *WasmMatcher) LoadModule(wasmPath string) error {
48107
// Read the WASM file
49108
wasmBytes, err := os.ReadFile(wasmPath)
@@ -74,27 +133,51 @@ func (m *WasmMatcher) LoadModule(wasmPath string) error {
74133
return fmt.Errorf("failed to create host module: %v", err)
75134
}
76135

77-
// Instantiate the module
78-
m.module, err = m.runtime.Instantiate(m.ctx, wasmBytes)
136+
// First compile the module to access the custom sections
137+
compiledModule, err := m.runtime.CompileModule(m.ctx, wasmBytes)
79138
if err != nil {
80-
return fmt.Errorf("error instantiating module: %v", err)
139+
return fmt.Errorf("error compiling module: %v", err)
81140
}
82141

83-
// Extract rule IDs if this is a grammar module
84-
rulesFunc := m.module.ExportedFunction("getRuleIds")
85-
if rulesFunc != nil {
86-
// In a real implementation, you would actually extract the rule IDs
87-
// by calling the exported function and reading the results
142+
// Get all custom sections from the module
143+
customSections := compiledModule.CustomSections()
144+
if customSections == nil {
145+
return fmt.Errorf("no custom sections found in module")
146+
}
88147

89-
// For now, just populate with some example rule IDs
90-
m.ruleIds = map[string]int{
91-
"Start": 0,
92-
"Expr": 1,
93-
"Term": 2,
94-
"Factor": 3,
148+
var ruleNamesSection api.CustomSection
149+
for _, section := range customSections {
150+
if section.Name() == "ruleNames" {
151+
ruleNamesSection = section
152+
break
95153
}
154+
}
155+
156+
if ruleNamesSection == nil {
157+
return fmt.Errorf("required custom section 'ruleNames' not found")
158+
}
159+
160+
// Parse rule names from the custom section data
161+
ruleNames, err := parseRuleNames(ruleNamesSection.Data())
162+
if err != nil {
163+
return fmt.Errorf("failed to parse rule names from custom section: %v", err)
164+
}
165+
166+
// Now instantiate the module
167+
m.module, err = m.runtime.InstantiateModule(m.ctx, compiledModule, wazero.NewModuleConfig())
168+
if err != nil {
169+
return fmt.Errorf("error instantiating module: %v", err)
170+
}
171+
172+
// Build the ruleIds map (mapping from name to index)
173+
m.ruleIds = make(map[string]int, len(ruleNames))
174+
for i, name := range ruleNames {
175+
m.ruleIds[name] = i
176+
}
96177

97-
m.defaultStartRule = "Start"
178+
// Set the default start rule to the first rule
179+
if len(ruleNames) > 0 {
180+
m.defaultStartRule = ruleNames[0]
98181
}
99182

100183
return nil

0 commit comments

Comments
 (0)