Skip to content

Commit fc7fc4d

Browse files
authored
Merge pull request odin-lang#5289 from JackMordaunt/jfm-sync_chan_refactor
Jfm sync chan refactor
2 parents 0ed6cdc + 3c3fd6e commit fc7fc4d

File tree

2 files changed

+251
-31
lines changed

2 files changed

+251
-31
lines changed

core/sync/chan/chan.odin

Lines changed: 74 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@ import "core:mem"
77
import "core:sync"
88
import "core:math/rand"
99

10+
when ODIN_TEST {
11+
/*
12+
Hook for testing _try_select_raw allowing the test harness to manipulate the
13+
channels prior to the select actually operating on them.
14+
*/
15+
__try_select_raw_pause : proc() = nil
16+
}
17+
1018
/*
1119
Determines what operations `Chan` supports.
1220
*/
@@ -1105,15 +1113,27 @@ can_send :: proc "contextless" (c: ^Raw_Chan) -> bool {
11051113
return c.w_waiting == 0
11061114
}
11071115

1116+
/*
1117+
Specifies the direction of the selected channel.
1118+
*/
1119+
Select_Status :: enum {
1120+
None,
1121+
Recv,
1122+
Send,
1123+
}
1124+
11081125

11091126
/*
1110-
Attempts to either send or receive messages on the specified channels.
1127+
Attempts to either send or receive messages on the specified channels without blocking.
11111128
1112-
`select_raw` first identifies which channels have messages ready to be received
1129+
`try_select_raw` first identifies which channels have messages ready to be received
11131130
and which are available for sending. It then randomly selects one operation
11141131
(either a send or receive) to perform.
11151132
1133+
If no channels have messages ready, the procedure is a noop.
1134+
11161135
Note: Each message in `send_msgs` corresponds to the send channel at the same index in `sends`.
1136+
If the message is nil, corresponding send channel will be skipped.
11171137
11181138
**Inputs**
11191139
- `recv`: A slice of channels to read from
@@ -1145,18 +1165,18 @@ Example:
11451165
// where the value from the read should be stored
11461166
received_value: int
11471167
1148-
idx, ok := chan.select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value)
1168+
idx, ok := chan.try_select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value)
11491169
fmt.println("SELECT: ", idx, ok)
11501170
fmt.println("RECEIVED VALUE ", received_value)
11511171
1152-
idx, ok = chan.select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value)
1172+
idx, ok = chan.try_select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value)
11531173
fmt.println("SELECT: ", idx, ok)
11541174
fmt.println("RECEIVED VALUE ", received_value)
11551175
11561176
// closing of a channel also affects the select operation
11571177
chan.close(c)
11581178
1159-
idx, ok = chan.select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value)
1179+
idx, ok = chan.try_select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value)
11601180
fmt.println("SELECT: ", idx, ok)
11611181
}
11621182
@@ -1170,51 +1190,74 @@ Output:
11701190
11711191
*/
11721192
@(require_results)
1173-
select_raw :: proc "odin" (recvs: []^Raw_Chan, sends: []^Raw_Chan, send_msgs: []rawptr, recv_out: rawptr) -> (select_idx: int, ok: bool) #no_bounds_check {
1193+
try_select_raw :: proc "odin" (recvs: []^Raw_Chan, sends: []^Raw_Chan, send_msgs: []rawptr, recv_out: rawptr) -> (select_idx: int, status: Select_Status) #no_bounds_check {
11741194
Select_Op :: struct {
11751195
idx: int, // local to the slice that was given
11761196
is_recv: bool,
11771197
}
11781198

11791199
candidate_count := builtin.len(recvs)+builtin.len(sends)
11801200
candidates := ([^]Select_Op)(intrinsics.alloca(candidate_count*size_of(Select_Op), align_of(Select_Op)))
1181-
count := 0
11821201

1183-
for c, i in recvs {
1184-
if can_recv(c) {
1185-
candidates[count] = {
1186-
is_recv = true,
1187-
idx = i,
1202+
try_loop: for {
1203+
count := 0
1204+
1205+
for c, i in recvs {
1206+
if can_recv(c) {
1207+
candidates[count] = {
1208+
is_recv = true,
1209+
idx = i,
1210+
}
1211+
count += 1
11881212
}
1189-
count += 1
11901213
}
1191-
}
11921214

1193-
for c, i in sends {
1194-
if can_send(c) {
1195-
candidates[count] = {
1196-
is_recv = false,
1197-
idx = i,
1215+
for c, i in sends {
1216+
if i > builtin.len(send_msgs)-1 || send_msgs[i] == nil {
1217+
continue
1218+
}
1219+
if can_send(c) {
1220+
candidates[count] = {
1221+
is_recv = false,
1222+
idx = i,
1223+
}
1224+
count += 1
11981225
}
1199-
count += 1
12001226
}
1201-
}
12021227

1203-
if count == 0 {
1204-
return
1205-
}
1228+
if count == 0 {
1229+
return -1, .None
1230+
}
1231+
1232+
when ODIN_TEST {
1233+
if __try_select_raw_pause != nil {
1234+
__try_select_raw_pause()
1235+
}
1236+
}
12061237

1207-
select_idx = rand.int_max(count) if count > 0 else 0
1238+
candidate_idx := rand.int_max(count) if count > 0 else 0
12081239

1209-
sel := candidates[select_idx]
1210-
if sel.is_recv {
1211-
ok = recv_raw(recvs[sel.idx], recv_out)
1212-
} else {
1213-
ok = send_raw(sends[sel.idx], send_msgs[sel.idx])
1240+
sel := candidates[candidate_idx]
1241+
if sel.is_recv {
1242+
status = .Recv
1243+
if !try_recv_raw(recvs[sel.idx], recv_out) {
1244+
continue try_loop
1245+
}
1246+
} else {
1247+
status = .Send
1248+
if !try_send_raw(sends[sel.idx], send_msgs[sel.idx]) {
1249+
continue try_loop
1250+
}
1251+
}
1252+
1253+
return sel.idx, status
12141254
}
1215-
return
12161255
}
12171256

1257+
@(require_results, deprecated = "use try_select_raw")
1258+
select_raw :: proc "odin" (recvs: []^Raw_Chan, sends: []^Raw_Chan, send_msgs: []rawptr, recv_out: rawptr) -> (select_idx: int, status: Select_Status) #no_bounds_check {
1259+
return try_select_raw(recvs, sends, send_msgs, recv_out)
1260+
}
12181261

12191262
/*
12201263
`Raw_Queue` is a non-thread-safe queue implementation designed to store messages

tests/core/sync/chan/test_core_sync_chan.odin

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,3 +272,180 @@ test_accept_message_from_closed_buffered_chan :: proc(t: ^testing.T) {
272272
testing.expect_value(t, result, 64)
273273
testing.expect(t, ok)
274274
}
275+
276+
// Ensures that if any input channel is eligible to receive or send, the try_select_raw
277+
// operation will process it.
278+
@test
279+
test_try_select_raw_happy :: proc(t: ^testing.T) {
280+
testing.set_fail_timeout(t, FAIL_TIME)
281+
282+
recv1, recv1_err := chan.create(chan.Chan(int), context.allocator)
283+
284+
assert(recv1_err == nil, "allocation failed")
285+
defer chan.destroy(recv1)
286+
287+
recv2, recv2_err := chan.create(chan.Chan(int), 1, context.allocator)
288+
289+
assert(recv2_err == nil, "allocation failed")
290+
defer chan.destroy(recv2)
291+
292+
send1, send1_err := chan.create(chan.Chan(int), 1, context.allocator)
293+
294+
assert(send1_err == nil, "allocation failed")
295+
defer chan.destroy(send1)
296+
297+
msg := 42
298+
299+
// Preload recv2 to make it eligible for selection.
300+
testing.expect_value(t, chan.send(recv2, msg), true)
301+
302+
recvs := [?]^chan.Raw_Chan{recv1, recv2}
303+
sends := [?]^chan.Raw_Chan{send1}
304+
msgs := [?]rawptr{&msg}
305+
received_value: int
306+
307+
iteration_count := 0
308+
did_none_count := 0
309+
did_send_count := 0
310+
did_receive_count := 0
311+
312+
// This loop is expected to iterate three times. Twice to do the receive and
313+
// send operations, and a third time to exit.
314+
receive_loop: for {
315+
316+
iteration_count += 1
317+
318+
idx, status := chan.try_select_raw(recvs[:], sends[:], msgs[:], &received_value)
319+
320+
switch status {
321+
case .None:
322+
did_none_count += 1
323+
break receive_loop
324+
325+
case .Recv:
326+
did_receive_count += 1
327+
testing.expect_value(t, idx, 1)
328+
testing.expect_value(t, received_value, msg)
329+
received_value = 0
330+
331+
case .Send:
332+
did_send_count += 1
333+
testing.expect_value(t, idx, 0)
334+
v, ok := chan.try_recv(send1)
335+
testing.expect_value(t, ok, true)
336+
testing.expect_value(t, v, msg)
337+
msgs[0] = nil // nil out the message to avoid constantly resending the same value.
338+
}
339+
}
340+
341+
testing.expect_value(t, iteration_count, 3)
342+
testing.expect_value(t, did_none_count, 1)
343+
testing.expect_value(t, did_receive_count, 1)
344+
testing.expect_value(t, did_send_count, 1)
345+
}
346+
347+
// Ensures that if no input channels are eligible to receive or send, the
348+
// try_select_raw operation does not block.
349+
@test
350+
test_try_select_raw_default_state :: proc(t: ^testing.T) {
351+
testing.set_fail_timeout(t, FAIL_TIME)
352+
353+
recv1, recv1_err := chan.create(chan.Chan(int), context.allocator)
354+
355+
assert(recv1_err == nil, "allocation failed")
356+
defer chan.destroy(recv1)
357+
358+
recv2, recv2_err := chan.create(chan.Chan(int), context.allocator)
359+
360+
assert(recv2_err == nil, "allocation failed")
361+
defer chan.destroy(recv2)
362+
363+
recvs := [?]^chan.Raw_Chan{recv1, recv2}
364+
received_value: int
365+
366+
idx, status := chan.try_select_raw(recvs[:], nil, nil, &received_value)
367+
368+
testing.expect_value(t, idx, -1)
369+
testing.expect_value(t, status, chan.Select_Status.None)
370+
}
371+
372+
// Ensures that the operation will not block even if the input channels are
373+
// consumed by a competing thread; that is, a value is received from another
374+
// thread between calls to can_{send,recv} and try_{send,recv}_raw.
375+
@test
376+
test_try_select_raw_no_toctou :: proc(t: ^testing.T) {
377+
testing.set_fail_timeout(t, FAIL_TIME)
378+
379+
// Trigger will be used to coordinate between the thief and the try_select.
380+
trigger, trigger_err := chan.create(chan.Chan(any), context.allocator)
381+
382+
assert(trigger_err == nil, "allocation failed")
383+
defer chan.destroy(trigger)
384+
385+
@(static)
386+
__global_context_for_test: rawptr
387+
388+
__global_context_for_test = &trigger
389+
defer __global_context_for_test = nil
390+
391+
// Setup the pause proc. This will be invoked after the input channels are
392+
// checked for eligibility but before any channel operations are attempted.
393+
chan.__try_select_raw_pause = proc() {
394+
trigger := (cast(^chan.Chan(any))(__global_context_for_test))^
395+
396+
// Notify the thief that we are paused so that it can steal the value.
397+
_ = chan.send(trigger, "signal")
398+
399+
// Wait for comfirmation of the burglary.
400+
_, _ = chan.recv(trigger)
401+
}
402+
403+
defer chan.__try_select_raw_pause = nil
404+
405+
recv1, recv1_err := chan.create(chan.Chan(int), 1, context.allocator)
406+
407+
assert(recv1_err == nil, "allocation failed")
408+
defer chan.destroy(recv1)
409+
410+
Context :: struct {
411+
recv1: chan.Chan(int),
412+
trigger: chan.Chan(any),
413+
}
414+
415+
ctx := Context{
416+
recv1 = recv1,
417+
trigger = trigger,
418+
}
419+
420+
// Spin up a thread that will steal the value from the input channel after
421+
// try_select has already considered it eligible for selection.
422+
thief := thread.create_and_start_with_poly_data(ctx, proc(ctx: Context) {
423+
// Wait for eligibility check.
424+
_, _ = chan.recv(ctx.trigger)
425+
426+
// Steal the value.
427+
v, ok := chan.recv(ctx.recv1)
428+
429+
assert(ok, "recv1: expected to receive a value")
430+
assert(v == 42, "recv1: unexpected receive value")
431+
432+
// Notify select that we have stolen the value and that it can proceed.
433+
_ = chan.send(ctx.trigger, "signal")
434+
})
435+
436+
recvs := [?]^chan.Raw_Chan{recv1}
437+
received_value: int
438+
439+
// Ensure channel is eligible prior to entering the select.
440+
testing.expect_value(t, chan.send(recv1, 42), true)
441+
442+
// Execute the try_select_raw, assert that we don't block, and that we receive
443+
// .None status since the value was stolen by the other thread.
444+
idx, status := chan.try_select_raw(recvs[:], nil, nil, &received_value)
445+
446+
testing.expect_value(t, idx, -1)
447+
testing.expect_value(t, status, chan.Select_Status.None)
448+
449+
thread.join(thief)
450+
thread.destroy(thief)
451+
}

0 commit comments

Comments
 (0)