@@ -272,3 +272,180 @@ test_accept_message_from_closed_buffered_chan :: proc(t: ^testing.T) {
272272 testing.expect_value (t, result, 64 )
273273 testing.expect (t, ok)
274274}
275+
276+ // Ensures that if any input channel is eligible to receive or send, the try_select_raw
277+ // operation will process it.
278+ @test
279+ test_try_select_raw_happy :: proc (t: ^testing.T) {
280+ testing.set_fail_timeout (t, FAIL_TIME)
281+
282+ recv1, recv1_err := chan.create (chan.Chan (int ), context .allocator)
283+
284+ assert (recv1_err == nil , " allocation failed" )
285+ defer chan.destroy (recv1)
286+
287+ recv2, recv2_err := chan.create (chan.Chan (int ), 1 , context .allocator)
288+
289+ assert (recv2_err == nil , " allocation failed" )
290+ defer chan.destroy (recv2)
291+
292+ send1, send1_err := chan.create (chan.Chan (int ), 1 , context .allocator)
293+
294+ assert (send1_err == nil , " allocation failed" )
295+ defer chan.destroy (send1)
296+
297+ msg := 42
298+
299+ // Preload recv2 to make it eligible for selection.
300+ testing.expect_value (t, chan.send (recv2, msg), true )
301+
302+ recvs := [?]^chan.Raw_Chan{recv1, recv2}
303+ sends := [?]^chan.Raw_Chan{send1}
304+ msgs := [?]rawptr {&msg}
305+ received_value: int
306+
307+ iteration_count := 0
308+ did_none_count := 0
309+ did_send_count := 0
310+ did_receive_count := 0
311+
312+ // This loop is expected to iterate three times. Twice to do the receive and
313+ // send operations, and a third time to exit.
314+ receive_loop: for {
315+
316+ iteration_count += 1
317+
318+ idx, status := chan.try_select_raw (recvs[:], sends[:], msgs[:], &received_value)
319+
320+ switch status {
321+ case .None:
322+ did_none_count += 1
323+ break receive_loop
324+
325+ case .Recv:
326+ did_receive_count += 1
327+ testing.expect_value (t, idx, 1 )
328+ testing.expect_value (t, received_value, msg)
329+ received_value = 0
330+
331+ case .Send:
332+ did_send_count += 1
333+ testing.expect_value (t, idx, 0 )
334+ v, ok := chan.try_recv (send1)
335+ testing.expect_value (t, ok, true )
336+ testing.expect_value (t, v, msg)
337+ msgs[0 ] = nil // nil out the message to avoid constantly resending the same value.
338+ }
339+ }
340+
341+ testing.expect_value (t, iteration_count, 3 )
342+ testing.expect_value (t, did_none_count, 1 )
343+ testing.expect_value (t, did_receive_count, 1 )
344+ testing.expect_value (t, did_send_count, 1 )
345+ }
346+
347+ // Ensures that if no input channels are eligible to receive or send, the
348+ // try_select_raw operation does not block.
349+ @test
350+ test_try_select_raw_default_state :: proc (t: ^testing.T) {
351+ testing.set_fail_timeout (t, FAIL_TIME)
352+
353+ recv1, recv1_err := chan.create (chan.Chan (int ), context .allocator)
354+
355+ assert (recv1_err == nil , " allocation failed" )
356+ defer chan.destroy (recv1)
357+
358+ recv2, recv2_err := chan.create (chan.Chan (int ), context .allocator)
359+
360+ assert (recv2_err == nil , " allocation failed" )
361+ defer chan.destroy (recv2)
362+
363+ recvs := [?]^chan.Raw_Chan{recv1, recv2}
364+ received_value: int
365+
366+ idx, status := chan.try_select_raw (recvs[:], nil , nil , &received_value)
367+
368+ testing.expect_value (t, idx, -1 )
369+ testing.expect_value (t, status, chan.Select_Status.None)
370+ }
371+
372+ // Ensures that the operation will not block even if the input channels are
373+ // consumed by a competing thread; that is, a value is received from another
374+ // thread between calls to can_{send,recv} and try_{send,recv}_raw.
375+ @test
376+ test_try_select_raw_no_toctou :: proc (t: ^testing.T) {
377+ testing.set_fail_timeout (t, FAIL_TIME)
378+
379+ // Trigger will be used to coordinate between the thief and the try_select.
380+ trigger, trigger_err := chan.create (chan.Chan (any ), context .allocator)
381+
382+ assert (trigger_err == nil , " allocation failed" )
383+ defer chan.destroy (trigger)
384+
385+ @(static)
386+ __global_context_for_test: rawptr
387+
388+ __global_context_for_test = &trigger
389+ defer __global_context_for_test = nil
390+
391+ // Setup the pause proc. This will be invoked after the input channels are
392+ // checked for eligibility but before any channel operations are attempted.
393+ chan.__try_select_raw_pause = proc () {
394+ trigger := (cast (^chan.Chan (any ))(__global_context_for_test))^
395+
396+ // Notify the thief that we are paused so that it can steal the value.
397+ _ = chan.send (trigger, " signal" )
398+
399+ // Wait for comfirmation of the burglary.
400+ _, _ = chan.recv (trigger)
401+ }
402+
403+ defer chan.__try_select_raw_pause = nil
404+
405+ recv1, recv1_err := chan.create (chan.Chan (int ), 1 , context .allocator)
406+
407+ assert (recv1_err == nil , " allocation failed" )
408+ defer chan.destroy (recv1)
409+
410+ Context :: struct {
411+ recv1: chan.Chan (int ),
412+ trigger: chan.Chan (any ),
413+ }
414+
415+ ctx := Context{
416+ recv1 = recv1,
417+ trigger = trigger,
418+ }
419+
420+ // Spin up a thread that will steal the value from the input channel after
421+ // try_select has already considered it eligible for selection.
422+ thief := thread.create_and_start_with_poly_data (ctx, proc (ctx: Context) {
423+ // Wait for eligibility check.
424+ _, _ = chan.recv (ctx.trigger)
425+
426+ // Steal the value.
427+ v, ok := chan.recv (ctx.recv1)
428+
429+ assert (ok, " recv1: expected to receive a value" )
430+ assert (v == 42 , " recv1: unexpected receive value" )
431+
432+ // Notify select that we have stolen the value and that it can proceed.
433+ _ = chan.send (ctx.trigger, " signal" )
434+ })
435+
436+ recvs := [?]^chan.Raw_Chan{recv1}
437+ received_value: int
438+
439+ // Ensure channel is eligible prior to entering the select.
440+ testing.expect_value (t, chan.send (recv1, 42 ), true )
441+
442+ // Execute the try_select_raw, assert that we don't block, and that we receive
443+ // .None status since the value was stolen by the other thread.
444+ idx, status := chan.try_select_raw (recvs[:], nil , nil , &received_value)
445+
446+ testing.expect_value (t, idx, -1 )
447+ testing.expect_value (t, status, chan.Select_Status.None)
448+
449+ thread.join (thief)
450+ thread.destroy (thief)
451+ }
0 commit comments