Skip to content

Commit c03d11e

Browse files
committed
Switch back to Marks method of tracking visited
This change switches back from the "new" way to keep track of visited node in the occurs check to set Marks, like the rust compiler does.
1 parent d0b1c71 commit c03d11e

File tree

3 files changed

+52
-27
lines changed

3 files changed

+52
-27
lines changed

src/check/check_types/occurs.zig

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@ const store = @import("../../types/store.zig");
99

1010
const Ident = base.Ident;
1111

12+
const MkSafeList = collections.SafeList;
1213
const exitOnOutOfMemory = collections.utils.exitOnOom;
1314

1415
const Store = store.Store;
16+
const DescStoreIdx = store.DescStoreIdx;
1517

1618
const Desc = types.Descriptor;
1719
const Content = types.Content;
@@ -22,9 +24,18 @@ const Tag = types.Tag;
2224

2325
/// Check if a variable is recursive
2426
///
25-
/// This function uses `Scratch` as to hold intermediate values. It resets it
26-
/// before running.
27-
pub fn occurs(types_store: *const Store, scratch: *Scratch, var_: Var) bool {
27+
/// This uses `Scratch` as to hold intermediate values. `occurs` will reset it
28+
/// before each run.
29+
///
30+
/// This function accepts a mutable reference to `Store`, but guarantees that it
31+
/// _only_ modifies a variable's `Mark`. Before returning, all visited nodes'
32+
/// `Mark`s will be reset to `none`.
33+
///
34+
/// TODO: See if there's a way to represent this ^ in the type system? If we
35+
/// switch the types_store descriptors to use a multi list (which we should do
36+
/// anyway), maybe we can only pass in only a mutable ref to the backing `Mark`s
37+
/// array?
38+
pub fn occurs(types_store: *Store, scratch: *Scratch, var_: Var) bool {
2839
scratch.reset();
2940

3041
var result = false;
@@ -36,20 +47,27 @@ pub fn occurs(types_store: *const Store, scratch: *Scratch, var_: Var) bool {
3647
},
3748
};
3849

50+
for (scratch.visited.items.items[0..]) |visited_desc_idx| {
51+
types_store.setDescMark(visited_desc_idx, Mark.none);
52+
}
53+
3954
return result;
4055
}
4156

4257
/// This is an intermediate struct used when checking occurrences.
4358
const CheckOccurs = struct {
4459
const Self = @This();
4560

46-
types_store: *const Store,
61+
types_store: *Store,
4762
scratch: *Scratch,
4863

4964
/// Init CheckOccurs
5065
///
5166
/// Note that this struct does not own any of it's fields
52-
fn init(types_store: *const Store, scratch: *Scratch) Self {
67+
///
68+
/// This function accepts a mutable reference to `Store`, and _must_ only
69+
/// modify a var's `Mark`
70+
fn init(types_store: *Store, scratch: *Scratch) Self {
5371
return .{ .types_store = types_store, .scratch = scratch };
5472
}
5573

@@ -59,15 +77,14 @@ const CheckOccurs = struct {
5977
fn occurs(self: *Self, var_: Var) error{Occurs}!void {
6078
const root = self.types_store.resolveVar(var_);
6179

62-
if (self.scratch.hasSeenVar(root.var_)) {
63-
// If we've already seen this var, then it's recursive
64-
return error.Occurs;
65-
} else if (self.scratch.hasVisitedVar(root.var_)) {
80+
if (root.desc.mark == .visited) {
6681
// If we've already visited this var and not errored, then it's not recursive
6782
return;
83+
} else if (self.scratch.hasSeenVar(root.var_)) {
84+
// If we've already seen this var, then it's recursive
85+
return error.Occurs;
6886
} else {
6987
self.scratch.appendSeen(var_);
70-
7188
switch (root.desc.content) {
7289
.structure => |flat_type| {
7390
switch (flat_type) {
@@ -121,9 +138,10 @@ const CheckOccurs = struct {
121138
.pure => {},
122139
.err => {},
123140
}
124-
125141
self.scratch.popSeen();
126-
self.scratch.appendVisited(var_);
142+
143+
self.scratch.appendVisited(root.desc_idx);
144+
self.types_store.setDescMark(root.desc_idx, Mark.visited);
127145
}
128146
}
129147

@@ -154,8 +172,8 @@ const Scratch = struct {
154172
gpa: std.mem.Allocator,
155173

156174
seen: Var.SafeList,
157-
visited: Var.SafeList,
158175
err_chain: Var.SafeList,
176+
visited: MkSafeList(DescStoreIdx),
159177

160178
fn init(gpa: std.mem.Allocator) Self {
161179
// TODO: eventually use herusitics here to determine sensible defaults
@@ -164,21 +182,21 @@ const Scratch = struct {
164182
return .{
165183
.gpa = gpa,
166184
.seen = Var.SafeList.initCapacity(gpa, 32),
167-
.visited = Var.SafeList.initCapacity(gpa, 32),
168185
.err_chain = Var.SafeList.initCapacity(gpa, 32),
186+
.visited = MkSafeList(DescStoreIdx).initCapacity(gpa, 64),
169187
};
170188
}
171189

172190
fn deinit(self: *Self) void {
173191
self.seen.deinit(self.gpa);
174-
self.visited.deinit(self.gpa);
175192
self.err_chain.deinit(self.gpa);
193+
self.visited.deinit(self.gpa);
176194
}
177195

178196
fn reset(self: *Self) void {
179197
self.seen.items.clearRetainingCapacity();
180-
self.visited.items.clearRetainingCapacity();
181198
self.err_chain.items.clearRetainingCapacity();
199+
self.visited.items.clearRetainingCapacity();
182200
}
183201

184202
fn hasSeenVar(self: *const Self, var_: Var) bool {
@@ -196,15 +214,8 @@ const Scratch = struct {
196214
_ = self.seen.items.pop();
197215
}
198216

199-
fn hasVisitedVar(self: *const Self, var_: Var) bool {
200-
for (self.visited.items.items) |visited_var| {
201-
if (visited_var == var_) return true;
202-
}
203-
return false;
204-
}
205-
206-
fn appendVisited(self: *Self, var_: Var) void {
207-
_ = self.visited.append(self.gpa, var_);
217+
fn appendVisited(self: *Self, desc_idx: DescStoreIdx) void {
218+
_ = self.visited.append(self.gpa, desc_idx);
208219
}
209220

210221
fn appendErrChain(self: *Self, var_: Var) void {
@@ -447,6 +458,10 @@ test "occurs: recursive tag union (v = TagUnion { Foo(v) } with ext = v)" {
447458
const err_chain = scratch.errChainSlice();
448459
try std.testing.expectEqual(1, err_chain.len);
449460
try std.testing.expectEqual(linked_list, err_chain[0]);
461+
462+
for (scratch.visited.items.items[0..]) |visited_desc_idx| {
463+
try std.testing.expectEqual(Mark.none, types_store.getDesc(visited_desc_idx).mark);
464+
}
450465
}
451466

452467
test "occurs: nested recursive tag union (v = TagUnion { Cons(elem, Box(v)) } )" {
@@ -486,4 +501,8 @@ test "occurs: nested recursive tag union (v = TagUnion { Cons(elem, Box(v)) } )"
486501
try std.testing.expect(err_chain.len == 2);
487502
try std.testing.expectEqual(err_chain[0], boxed_linked_list);
488503
try std.testing.expectEqual(err_chain[1], linked_list);
504+
505+
for (scratch.visited.items.items[0..]) |visited_desc_idx| {
506+
try std.testing.expectEqual(Mark.none, types_store.getDesc(visited_desc_idx).mark);
507+
}
489508
}

src/types/store.zig

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,9 +432,13 @@ const DescStore = struct {
432432
}
433433

434434
/// A type-safe index into the store
435+
/// This type is made public below
435436
const Idx = enum(u32) { _ };
436437
};
437438

439+
/// An index into the desc store
440+
pub const DescStoreIdx = DescStore.Idx;
441+
438442
// path compression
439443

440444
test "resolveVarAndCompressPath - flattens redirect chain to flex_var" {

src/types/types.zig

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,16 @@ pub const Rank = enum(u4) {
5454
/// A type variable mark
5555
///
5656
/// Marks are temporary annotations used during various phases of type inference
57-
/// and type checking to track state and avoid redundant work.
57+
/// and type checking to track state.
5858
///
5959
/// Some places `Mark` is used:
60+
/// * Marking variables as visited in occurs checks to avoid redundant work
6061
/// * Marking variables for generalizing during solving
6162
pub const Mark = enum(u32) {
6263
const Self = @This();
6364

64-
none = 0,
65+
visited = 0,
66+
none = 1,
6567
_,
6668

6769
/// Get the next mark

0 commit comments

Comments
 (0)