@@ -9,9 +9,11 @@ const store = @import("../../types/store.zig");
9
9
10
10
const Ident = base .Ident ;
11
11
12
+ const MkSafeList = collections .SafeList ;
12
13
const exitOnOutOfMemory = collections .utils .exitOnOom ;
13
14
14
15
const Store = store .Store ;
16
+ const DescStoreIdx = store .DescStoreIdx ;
15
17
16
18
const Desc = types .Descriptor ;
17
19
const Content = types .Content ;
@@ -22,9 +24,18 @@ const Tag = types.Tag;
22
24
23
25
/// Check if a variable is recursive
24
26
///
25
- /// This function uses `Scratch` as to hold intermediate values. It resets it
26
- /// before running.
27
- pub fn occurs (types_store : * const Store , scratch : * Scratch , var_ : Var ) bool {
27
+ /// This uses `Scratch` as to hold intermediate values. `occurs` will reset it
28
+ /// before each run.
29
+ ///
30
+ /// This function accepts a mutable reference to `Store`, but guarantees that it
31
+ /// _only_ modifies a variable's `Mark`. Before returning, all visited nodes'
32
+ /// `Mark`s will be reset to `none`.
33
+ ///
34
+ /// TODO: See if there's a way to represent this ^ in the type system? If we
35
+ /// switch the types_store descriptors to use a multi list (which we should do
36
+ /// anyway), maybe we can only pass in only a mutable ref to the backing `Mark`s
37
+ /// array?
38
+ pub fn occurs (types_store : * Store , scratch : * Scratch , var_ : Var ) bool {
28
39
scratch .reset ();
29
40
30
41
var result = false ;
@@ -36,20 +47,27 @@ pub fn occurs(types_store: *const Store, scratch: *Scratch, var_: Var) bool {
36
47
},
37
48
};
38
49
50
+ for (scratch .visited .items .items [0.. ]) | visited_desc_idx | {
51
+ types_store .setDescMark (visited_desc_idx , Mark .none );
52
+ }
53
+
39
54
return result ;
40
55
}
41
56
42
57
/// This is an intermediate struct used when checking occurrences.
43
58
const CheckOccurs = struct {
44
59
const Self = @This ();
45
60
46
- types_store : * const Store ,
61
+ types_store : * Store ,
47
62
scratch : * Scratch ,
48
63
49
64
/// Init CheckOccurs
50
65
///
51
66
/// Note that this struct does not own any of it's fields
52
- fn init (types_store : * const Store , scratch : * Scratch ) Self {
67
+ ///
68
+ /// This function accepts a mutable reference to `Store`, and _must_ only
69
+ /// modify a var's `Mark`
70
+ fn init (types_store : * Store , scratch : * Scratch ) Self {
53
71
return .{ .types_store = types_store , .scratch = scratch };
54
72
}
55
73
@@ -59,15 +77,14 @@ const CheckOccurs = struct {
59
77
fn occurs (self : * Self , var_ : Var ) error {Occurs }! void {
60
78
const root = self .types_store .resolveVar (var_ );
61
79
62
- if (self .scratch .hasSeenVar (root .var_ )) {
63
- // If we've already seen this var, then it's recursive
64
- return error .Occurs ;
65
- } else if (self .scratch .hasVisitedVar (root .var_ )) {
80
+ if (root .desc .mark == .visited ) {
66
81
// If we've already visited this var and not errored, then it's not recursive
67
82
return ;
83
+ } else if (self .scratch .hasSeenVar (root .var_ )) {
84
+ // If we've already seen this var, then it's recursive
85
+ return error .Occurs ;
68
86
} else {
69
87
self .scratch .appendSeen (var_ );
70
-
71
88
switch (root .desc .content ) {
72
89
.structure = > | flat_type | {
73
90
switch (flat_type ) {
@@ -121,9 +138,10 @@ const CheckOccurs = struct {
121
138
.pure = > {},
122
139
.err = > {},
123
140
}
124
-
125
141
self .scratch .popSeen ();
126
- self .scratch .appendVisited (var_ );
142
+
143
+ self .scratch .appendVisited (root .desc_idx );
144
+ self .types_store .setDescMark (root .desc_idx , Mark .visited );
127
145
}
128
146
}
129
147
@@ -154,8 +172,8 @@ const Scratch = struct {
154
172
gpa : std.mem.Allocator ,
155
173
156
174
seen : Var.SafeList ,
157
- visited : Var.SafeList ,
158
175
err_chain : Var.SafeList ,
176
+ visited : MkSafeList (DescStoreIdx ),
159
177
160
178
fn init (gpa : std.mem.Allocator ) Self {
161
179
// TODO: eventually use herusitics here to determine sensible defaults
@@ -164,21 +182,21 @@ const Scratch = struct {
164
182
return .{
165
183
.gpa = gpa ,
166
184
.seen = Var .SafeList .initCapacity (gpa , 32 ),
167
- .visited = Var .SafeList .initCapacity (gpa , 32 ),
168
185
.err_chain = Var .SafeList .initCapacity (gpa , 32 ),
186
+ .visited = MkSafeList (DescStoreIdx ).initCapacity (gpa , 64 ),
169
187
};
170
188
}
171
189
172
190
fn deinit (self : * Self ) void {
173
191
self .seen .deinit (self .gpa );
174
- self .visited .deinit (self .gpa );
175
192
self .err_chain .deinit (self .gpa );
193
+ self .visited .deinit (self .gpa );
176
194
}
177
195
178
196
fn reset (self : * Self ) void {
179
197
self .seen .items .clearRetainingCapacity ();
180
- self .visited .items .clearRetainingCapacity ();
181
198
self .err_chain .items .clearRetainingCapacity ();
199
+ self .visited .items .clearRetainingCapacity ();
182
200
}
183
201
184
202
fn hasSeenVar (self : * const Self , var_ : Var ) bool {
@@ -196,15 +214,8 @@ const Scratch = struct {
196
214
_ = self .seen .items .pop ();
197
215
}
198
216
199
- fn hasVisitedVar (self : * const Self , var_ : Var ) bool {
200
- for (self .visited .items .items ) | visited_var | {
201
- if (visited_var == var_ ) return true ;
202
- }
203
- return false ;
204
- }
205
-
206
- fn appendVisited (self : * Self , var_ : Var ) void {
207
- _ = self .visited .append (self .gpa , var_ );
217
+ fn appendVisited (self : * Self , desc_idx : DescStoreIdx ) void {
218
+ _ = self .visited .append (self .gpa , desc_idx );
208
219
}
209
220
210
221
fn appendErrChain (self : * Self , var_ : Var ) void {
@@ -447,6 +458,10 @@ test "occurs: recursive tag union (v = TagUnion { Foo(v) } with ext = v)" {
447
458
const err_chain = scratch .errChainSlice ();
448
459
try std .testing .expectEqual (1 , err_chain .len );
449
460
try std .testing .expectEqual (linked_list , err_chain [0 ]);
461
+
462
+ for (scratch .visited .items .items [0.. ]) | visited_desc_idx | {
463
+ try std .testing .expectEqual (Mark .none , types_store .getDesc (visited_desc_idx ).mark );
464
+ }
450
465
}
451
466
452
467
test "occurs: nested recursive tag union (v = TagUnion { Cons(elem, Box(v)) } )" {
@@ -486,4 +501,8 @@ test "occurs: nested recursive tag union (v = TagUnion { Cons(elem, Box(v)) } )"
486
501
try std .testing .expect (err_chain .len == 2 );
487
502
try std .testing .expectEqual (err_chain [0 ], boxed_linked_list );
488
503
try std .testing .expectEqual (err_chain [1 ], linked_list );
504
+
505
+ for (scratch .visited .items .items [0.. ]) | visited_desc_idx | {
506
+ try std .testing .expectEqual (Mark .none , types_store .getDesc (visited_desc_idx ).mark );
507
+ }
489
508
}
0 commit comments