aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaytan <laytanlaats@hotmail.com>2025-06-12 21:51:34 +0200
committerGitHub <noreply@github.com>2025-06-12 21:51:34 +0200
commitfc7fc4d5cdcdb8c67f422b04e8992a1fff966235 (patch)
treedd9d39ebf0c877efcd8be23a5f287793c0cec543
parent0ed6cdc98eead010b448a10ac2c45d2695563be9 (diff)
parent3c3fd6e580b017b2243303221709856b9c663a5c (diff)
Merge pull request #5289 from JackMordaunt/jfm-sync_chan_refactor
Jfm sync chan refactor
-rw-r--r--core/sync/chan/chan.odin105
-rw-r--r--tests/core/sync/chan/test_core_sync_chan.odin177
2 files changed, 251 insertions, 31 deletions
diff --git a/core/sync/chan/chan.odin b/core/sync/chan/chan.odin
index eca4c28d7..c5a4cf317 100644
--- a/core/sync/chan/chan.odin
+++ b/core/sync/chan/chan.odin
@@ -7,6 +7,14 @@ import "core:mem"
import "core:sync"
import "core:math/rand"
+when ODIN_TEST {
+/*
+Hook for testing _try_select_raw allowing the test harness to manipulate the
+channels prior to the select actually operating on them.
+*/
+__try_select_raw_pause : proc() = nil
+}
+
/*
Determines what operations `Chan` supports.
*/
@@ -1105,15 +1113,27 @@ can_send :: proc "contextless" (c: ^Raw_Chan) -> bool {
return c.w_waiting == 0
}
+/*
+Specifies the direction of the selected channel.
+*/
+Select_Status :: enum {
+ None,
+ Recv,
+ Send,
+}
+
/*
-Attempts to either send or receive messages on the specified channels.
+Attempts to either send or receive messages on the specified channels without blocking.
-`select_raw` first identifies which channels have messages ready to be received
+`try_select_raw` first identifies which channels have messages ready to be received
and which are available for sending. It then randomly selects one operation
(either a send or receive) to perform.
+If no channels have messages ready, the procedure is a noop.
+
Note: Each message in `send_msgs` corresponds to the send channel at the same index in `sends`.
+If the message is nil, corresponding send channel will be skipped.
**Inputs**
- `recv`: A slice of channels to read from
@@ -1145,18 +1165,18 @@ Example:
// where the value from the read should be stored
received_value: int
- idx, ok := chan.select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value)
+ idx, ok := chan.try_select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value)
fmt.println("SELECT: ", idx, ok)
fmt.println("RECEIVED VALUE ", received_value)
- idx, ok = chan.select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value)
+ idx, ok = chan.try_select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value)
fmt.println("SELECT: ", idx, ok)
fmt.println("RECEIVED VALUE ", received_value)
// closing of a channel also affects the select operation
chan.close(c)
- idx, ok = chan.select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value)
+ idx, ok = chan.try_select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value)
fmt.println("SELECT: ", idx, ok)
}
@@ -1170,7 +1190,7 @@ Output:
*/
@(require_results)
-select_raw :: proc "odin" (recvs: []^Raw_Chan, sends: []^Raw_Chan, send_msgs: []rawptr, recv_out: rawptr) -> (select_idx: int, ok: bool) #no_bounds_check {
+try_select_raw :: proc "odin" (recvs: []^Raw_Chan, sends: []^Raw_Chan, send_msgs: []rawptr, recv_out: rawptr) -> (select_idx: int, status: Select_Status) #no_bounds_check {
Select_Op :: struct {
idx: int, // local to the slice that was given
is_recv: bool,
@@ -1178,43 +1198,66 @@ select_raw :: proc "odin" (recvs: []^Raw_Chan, sends: []^Raw_Chan, send_msgs: []
candidate_count := builtin.len(recvs)+builtin.len(sends)
candidates := ([^]Select_Op)(intrinsics.alloca(candidate_count*size_of(Select_Op), align_of(Select_Op)))
- count := 0
- for c, i in recvs {
- if can_recv(c) {
- candidates[count] = {
- is_recv = true,
- idx = i,
+ try_loop: for {
+ count := 0
+
+ for c, i in recvs {
+ if can_recv(c) {
+ candidates[count] = {
+ is_recv = true,
+ idx = i,
+ }
+ count += 1
}
- count += 1
}
- }
- for c, i in sends {
- if can_send(c) {
- candidates[count] = {
- is_recv = false,
- idx = i,
+ for c, i in sends {
+ if i > builtin.len(send_msgs)-1 || send_msgs[i] == nil {
+ continue
+ }
+ if can_send(c) {
+ candidates[count] = {
+ is_recv = false,
+ idx = i,
+ }
+ count += 1
}
- count += 1
}
- }
- if count == 0 {
- return
- }
+ if count == 0 {
+ return -1, .None
+ }
+
+ when ODIN_TEST {
+ if __try_select_raw_pause != nil {
+ __try_select_raw_pause()
+ }
+ }
- select_idx = rand.int_max(count) if count > 0 else 0
+ candidate_idx := rand.int_max(count) if count > 0 else 0
- sel := candidates[select_idx]
- if sel.is_recv {
- ok = recv_raw(recvs[sel.idx], recv_out)
- } else {
- ok = send_raw(sends[sel.idx], send_msgs[sel.idx])
+ sel := candidates[candidate_idx]
+ if sel.is_recv {
+ status = .Recv
+ if !try_recv_raw(recvs[sel.idx], recv_out) {
+ continue try_loop
+ }
+ } else {
+ status = .Send
+ if !try_send_raw(sends[sel.idx], send_msgs[sel.idx]) {
+ continue try_loop
+ }
+ }
+
+ return sel.idx, status
}
- return
}
+@(require_results, deprecated = "use try_select_raw")
+select_raw :: proc "odin" (recvs: []^Raw_Chan, sends: []^Raw_Chan, send_msgs: []rawptr, recv_out: rawptr) -> (select_idx: int, status: Select_Status) #no_bounds_check {
+ return try_select_raw(recvs, sends, send_msgs, recv_out)
+}
/*
`Raw_Queue` is a non-thread-safe queue implementation designed to store messages
diff --git a/tests/core/sync/chan/test_core_sync_chan.odin b/tests/core/sync/chan/test_core_sync_chan.odin
index 9b8d9b354..e8bb553b1 100644
--- a/tests/core/sync/chan/test_core_sync_chan.odin
+++ b/tests/core/sync/chan/test_core_sync_chan.odin
@@ -272,3 +272,180 @@ test_accept_message_from_closed_buffered_chan :: proc(t: ^testing.T) {
testing.expect_value(t, result, 64)
testing.expect(t, ok)
}
+
+// Ensures that if any input channel is eligible to receive or send, the try_select_raw
+// operation will process it.
+@test
+test_try_select_raw_happy :: proc(t: ^testing.T) {
+ testing.set_fail_timeout(t, FAIL_TIME)
+
+ recv1, recv1_err := chan.create(chan.Chan(int), context.allocator)
+
+ assert(recv1_err == nil, "allocation failed")
+ defer chan.destroy(recv1)
+
+ recv2, recv2_err := chan.create(chan.Chan(int), 1, context.allocator)
+
+ assert(recv2_err == nil, "allocation failed")
+ defer chan.destroy(recv2)
+
+ send1, send1_err := chan.create(chan.Chan(int), 1, context.allocator)
+
+ assert(send1_err == nil, "allocation failed")
+ defer chan.destroy(send1)
+
+ msg := 42
+
+ // Preload recv2 to make it eligible for selection.
+ testing.expect_value(t, chan.send(recv2, msg), true)
+
+ recvs := [?]^chan.Raw_Chan{recv1, recv2}
+ sends := [?]^chan.Raw_Chan{send1}
+ msgs := [?]rawptr{&msg}
+ received_value: int
+
+ iteration_count := 0
+ did_none_count := 0
+ did_send_count := 0
+ did_receive_count := 0
+
+ // This loop is expected to iterate three times. Twice to do the receive and
+ // send operations, and a third time to exit.
+ receive_loop: for {
+
+ iteration_count += 1
+
+ idx, status := chan.try_select_raw(recvs[:], sends[:], msgs[:], &received_value)
+
+ switch status {
+ case .None:
+ did_none_count += 1
+ break receive_loop
+
+ case .Recv:
+ did_receive_count += 1
+ testing.expect_value(t, idx, 1)
+ testing.expect_value(t, received_value, msg)
+ received_value = 0
+
+ case .Send:
+ did_send_count += 1
+ testing.expect_value(t, idx, 0)
+ v, ok := chan.try_recv(send1)
+ testing.expect_value(t, ok, true)
+ testing.expect_value(t, v, msg)
+ msgs[0] = nil // nil out the message to avoid constantly resending the same value.
+ }
+ }
+
+ testing.expect_value(t, iteration_count, 3)
+ testing.expect_value(t, did_none_count, 1)
+ testing.expect_value(t, did_receive_count, 1)
+ testing.expect_value(t, did_send_count, 1)
+}
+
+// Ensures that if no input channels are eligible to receive or send, the
+// try_select_raw operation does not block.
+@test
+test_try_select_raw_default_state :: proc(t: ^testing.T) {
+ testing.set_fail_timeout(t, FAIL_TIME)
+
+ recv1, recv1_err := chan.create(chan.Chan(int), context.allocator)
+
+ assert(recv1_err == nil, "allocation failed")
+ defer chan.destroy(recv1)
+
+ recv2, recv2_err := chan.create(chan.Chan(int), context.allocator)
+
+ assert(recv2_err == nil, "allocation failed")
+ defer chan.destroy(recv2)
+
+ recvs := [?]^chan.Raw_Chan{recv1, recv2}
+ received_value: int
+
+ idx, status := chan.try_select_raw(recvs[:], nil, nil, &received_value)
+
+ testing.expect_value(t, idx, -1)
+ testing.expect_value(t, status, chan.Select_Status.None)
+}
+
+// Ensures that the operation will not block even if the input channels are
+// consumed by a competing thread; that is, a value is received from another
+// thread between calls to can_{send,recv} and try_{send,recv}_raw.
+@test
+test_try_select_raw_no_toctou :: proc(t: ^testing.T) {
+ testing.set_fail_timeout(t, FAIL_TIME)
+
+ // Trigger will be used to coordinate between the thief and the try_select.
+ trigger, trigger_err := chan.create(chan.Chan(any), context.allocator)
+
+ assert(trigger_err == nil, "allocation failed")
+ defer chan.destroy(trigger)
+
+ @(static)
+ __global_context_for_test: rawptr
+
+ __global_context_for_test = &trigger
+ defer __global_context_for_test = nil
+
+ // Setup the pause proc. This will be invoked after the input channels are
+ // checked for eligibility but before any channel operations are attempted.
+ chan.__try_select_raw_pause = proc() {
+ trigger := (cast(^chan.Chan(any))(__global_context_for_test))^
+
+ // Notify the thief that we are paused so that it can steal the value.
+ _ = chan.send(trigger, "signal")
+
+ // Wait for comfirmation of the burglary.
+ _, _ = chan.recv(trigger)
+ }
+
+ defer chan.__try_select_raw_pause = nil
+
+ recv1, recv1_err := chan.create(chan.Chan(int), 1, context.allocator)
+
+ assert(recv1_err == nil, "allocation failed")
+ defer chan.destroy(recv1)
+
+ Context :: struct {
+ recv1: chan.Chan(int),
+ trigger: chan.Chan(any),
+ }
+
+ ctx := Context{
+ recv1 = recv1,
+ trigger = trigger,
+ }
+
+ // Spin up a thread that will steal the value from the input channel after
+ // try_select has already considered it eligible for selection.
+ thief := thread.create_and_start_with_poly_data(ctx, proc(ctx: Context) {
+ // Wait for eligibility check.
+ _, _ = chan.recv(ctx.trigger)
+
+ // Steal the value.
+ v, ok := chan.recv(ctx.recv1)
+
+ assert(ok, "recv1: expected to receive a value")
+ assert(v == 42, "recv1: unexpected receive value")
+
+ // Notify select that we have stolen the value and that it can proceed.
+ _ = chan.send(ctx.trigger, "signal")
+ })
+
+ recvs := [?]^chan.Raw_Chan{recv1}
+ received_value: int
+
+ // Ensure channel is eligible prior to entering the select.
+ testing.expect_value(t, chan.send(recv1, 42), true)
+
+ // Execute the try_select_raw, assert that we don't block, and that we receive
+ // .None status since the value was stolen by the other thread.
+ idx, status := chan.try_select_raw(recvs[:], nil, nil, &received_value)
+
+ testing.expect_value(t, idx, -1)
+ testing.expect_value(t, status, chan.Select_Status.None)
+
+ thread.join(thief)
+ thread.destroy(thief)
+}