11import sys
22
3+ import dask .array
34import msgpack
45import numpy as np
56import pytest
1213)
1314
1415
16+ def send_ws_updates (client , update_func , start = 1 , count = 1 , persist = None ):
17+ """Helper to send updates via Tiled client in websocket tests."""
18+ for i in range (start , start + count ):
19+ new_arr = np .arange (10 ) + i
20+ update_func (client , new_arr , i , persist = persist )
21+
22+
23+ # An update_func for send_ws_updates
24+ def overwrite_array (client , new_arr , seq_num , persist = None ):
25+ _ = seq_num # seq_num is unused for these updates
26+ client .write (new_arr , persist = persist )
27+
28+
29+ # An update_func for send_ws_updates
30+ def write_array_block (client , new_arr , seq_num , persist = None ):
31+ client .write_block (new_arr , block = (seq_num - 1 , 0 ), persist = persist )
32+
33+
34+ # An update_func for send_ws_updates
35+ def patch_array (client , new_arr , seq_num , persist = None ):
36+ _ = seq_num # seq_num is unused for these updates
37+ client .patch (new_arr , offset = (0 ,), persist = persist )
38+
39+
40+ # An update_func for send_ws_updates
41+ def append_array (client , new_arr , seq_num , persist = None ):
42+ client .patch (new_arr , offset = (10 * seq_num ,), extend = True , persist = persist )
43+
44+
45+ def receive_ws_updates (websocket , count = 1 ):
46+ """Helper to receive updates in websocket tests."""
47+ # Receive all updates
48+ received = []
49+ for _ in range (count + 1 ): # +1 for schema
50+ msg_bytes = websocket .receive_bytes ()
51+ msg = msgpack .unpackb (msg_bytes )
52+ received .append (msg )
53+
54+ # Verify all messages received (schema + n updates)
55+ assert len (received ) == count + 1
56+
57+ return received
58+
59+
60+ def verify_ws_updates (received , start = 1 , chunked = False ):
61+ """Verify that we received messages with the expected data"""
62+ for i , msg in enumerate (received ):
63+ if i == 0 : # schema
64+ assert "type" in msg
65+ assert "version" in msg
66+ else :
67+ assert "type" in msg
68+ assert "timestamp" in msg
69+ assert "payload" in msg
70+ if chunked :
71+ assert msg ["shape" ] == [1 , 10 ]
72+ else :
73+ assert msg ["shape" ] == [10 ]
74+
75+ # Verify payload contains the expected array data
76+ payload_array = np .frombuffer (msg ["payload" ], dtype = np .int64 )
77+ expected_array = np .arange (10 ) + (start - 1 ) + i
78+ np .testing .assert_array_equal (payload_array , expected_array )
79+
80+
1581def test_subscribe_immediately_after_creation_websockets (tiled_websocket_context ):
1682 context = tiled_websocket_context
1783 client = from_context (context )
@@ -26,36 +92,12 @@ def test_subscribe_immediately_after_creation_websockets(tiled_websocket_context
2692 "/api/v1/stream/single/test_stream_immediate?envelope_format=msgpack" ,
2793 headers = {"Authorization" : "Apikey secret" },
2894 ) as websocket :
29- # Write updates using Tiled client
30- for i in range (1 , 4 ):
31- new_arr = np .arange (10 ) + i
32- streaming_node .write (new_arr )
33-
34- # Receive all updates
35- received = []
36- for _ in range (3 ):
37- msg_bytes = websocket .receive_bytes ()
38- msg = msgpack .unpackb (msg_bytes )
39- received .append (msg )
40-
41- # Verify all updates received in order
42- assert len (received ) == 3
95+ # Send 3 updates using Tiled client that overwrite the array
96+ send_ws_updates (streaming_node , overwrite_array , count = 3 )
4397
44- # Check that we received messages with the expected data
45- for i , msg in enumerate (received ):
46- if i == 0 : # schema
47- assert "type" in msg
48- assert "version" in msg
49- else :
50- assert "type" in msg
51- assert "timestamp" in msg
52- assert "payload" in msg
53- assert msg ["shape" ] == [10 ]
54-
55- # Verify payload contains the expected array data
56- payload_array = np .frombuffer (msg ["payload" ], dtype = np .int64 )
57- expected_array = np .arange (10 ) + i
58- np .testing .assert_array_equal (payload_array , expected_array )
98+ # Receive and validate all updates
99+ received = receive_ws_updates (websocket , count = 3 )
100+ verify_ws_updates (received )
59101
60102
61103def test_websocket_connection_to_non_existent_node (tiled_websocket_context ):
@@ -93,38 +135,13 @@ def test_subscribe_after_first_update_websockets(tiled_websocket_context):
93135 "/api/v1/stream/single/test_stream_after_update?envelope_format=msgpack" ,
94136 headers = {"Authorization" : "Apikey secret" },
95137 ) as websocket :
96- # Write more updates
97- for i in range (2 , 4 ):
98- new_arr = np .arange (10 ) + i
99- streaming_node .write (new_arr )
138+ # Send 2 more updates that overwrite the array
139+ send_ws_updates (streaming_node , overwrite_array , start = 2 , count = 2 )
100140
101141 # Should only receive the 2 new updates
102- received = []
103- for _ in range (2 ):
104- msg_bytes = websocket .receive_bytes ()
105- msg = msgpack .unpackb (msg_bytes )
106- received .append (msg )
107-
108- # Verify only new updates received
109- assert len (received ) == 2
110-
111- # Check that we received messages with the expected data
112- for i , msg in enumerate (received ):
113- if i == 0 : # schema
114- assert "type" in msg
115- assert "version" in msg
116- else :
117- assert "type" in msg
118- assert "timestamp" in msg
119- assert "payload" in msg
120- assert msg ["shape" ] == [10 ]
121-
122- # Verify payload contains the expected array data
123- payload_array = np .frombuffer (msg ["payload" ], dtype = np .int64 )
124- expected_array = np .arange (10 ) + (
125- i + 1
126- ) # i+2 because we start from update 1
127- np .testing .assert_array_equal (payload_array , expected_array )
142+ received = receive_ws_updates (websocket , count = 2 )
143+ # Content starts with update #2
144+ verify_ws_updates (received , start = 2 )
128145
129146
130147def test_subscribe_after_first_update_from_beginning_websockets (
@@ -198,6 +215,121 @@ def test_subscribe_after_first_update_from_beginning_websockets(
198215 np .testing .assert_array_equal (payload_array , expected_array )
199216
200217
218+ @pytest .mark .parametrize ("write_op" , (overwrite_array , patch_array ))
219+ @pytest .mark .parametrize ("persist" , (None , True , False ))
220+ def test_updates_persist_write (tiled_websocket_context , write_op , persist ):
221+ context = tiled_websocket_context
222+ client = from_context (context )
223+ test_client = context .http_client
224+
225+ # Create streaming array node using Tiled client
226+ arr = np .arange (10 )
227+ streaming_node = client .write_array (arr , key = "test_stream_immediate" )
228+
229+ # Connect WebSocket using TestClient with msgpack format and authorization
230+ with test_client .websocket_connect (
231+ "/api/v1/stream/single/test_stream_immediate?envelope_format=msgpack" ,
232+ headers = {"Authorization" : "Apikey secret" },
233+ ) as websocket :
234+ # Send 3 updates using Tiled client that write values into the array
235+ send_ws_updates (streaming_node , write_op , count = 3 , persist = persist )
236+
237+ # Receive and validate all updates
238+ received = receive_ws_updates (websocket , count = 3 )
239+ verify_ws_updates (received )
240+
241+ # Verify values of persisted data
242+ if persist or persist is None :
243+ expected_persisted = np .arange (10 ) + 3 # Final sent values
244+ else :
245+ expected_persisted = arr # Original values
246+ persisted_data = streaming_node .read ()
247+ np .testing .assert_array_equal (persisted_data , expected_persisted )
248+
249+
250+ @pytest .mark .parametrize ("persist" , (None , True , False ))
251+ def test_updates_persist_write_block (tiled_websocket_context , persist ):
252+ context = tiled_websocket_context
253+ client = from_context (context )
254+ test_client = context .http_client
255+
256+ # Create a streaming chunked array node using Tiled client
257+ _arr = np .array ([np .arange (10 ) for _ in range (3 )])
258+ arr = dask .array .from_array (_arr , chunks = (1 , 10 )) # Chunk along first axis
259+ streaming_node = client .write_array (arr , key = "test_stream_immediate" )
260+
261+ # Connect WebSocket using TestClient with msgpack format and authorization
262+ with test_client .websocket_connect (
263+ "/api/v1/stream/single/test_stream_immediate?envelope_format=msgpack" ,
264+ headers = {"Authorization" : "Apikey secret" },
265+ ) as websocket :
266+ # Send 3 updates using Tiled client that write values into the array
267+ send_ws_updates (streaming_node , write_array_block , count = 3 , persist = persist )
268+
269+ # Receive and validate all updates
270+ received = receive_ws_updates (websocket , count = 3 )
271+ verify_ws_updates (received , chunked = True )
272+
273+ # Verify values of persisted data
274+ if persist or persist is None :
275+ # Combined effect of all sent values
276+ expected_persisted = np .array ([np .arange (10 ) + i for i in range (1 , 4 )])
277+ else :
278+ # Original values
279+ expected_persisted = arr
280+ persisted_data = streaming_node .read ()
281+ np .testing .assert_array_equal (persisted_data , expected_persisted )
282+
283+
284+ # Extending an array with persist=False is not yet supported
285+ @pytest .mark .parametrize ("persist" , (None , True ))
286+ def test_updates_persist_append (tiled_websocket_context , persist ):
287+ context = tiled_websocket_context
288+ client = from_context (context )
289+ test_client = context .http_client
290+
291+ # Create streaming array node using Tiled client
292+ arr = np .arange (10 )
293+ streaming_node = client .write_array (arr , key = "test_stream_immediate" )
294+
295+ # Connect WebSocket using TestClient with msgpack format and authorization
296+ with test_client .websocket_connect (
297+ "/api/v1/stream/single/test_stream_immediate?envelope_format=msgpack" ,
298+ headers = {"Authorization" : "Apikey secret" },
299+ ) as websocket :
300+ # Send 3 updates using Tiled client that append to the array
301+ send_ws_updates (streaming_node , append_array , count = 3 , persist = persist )
302+
303+ # Receive and validate all updates
304+ received = receive_ws_updates (websocket , count = 3 )
305+ verify_ws_updates (received )
306+
307+ # Verify values of persisted data
308+ if persist or persist is None :
309+ # Combined effect of all sent values
310+ expected_persisted = np .array (
311+ [np .arange (10 ) + i for i in range (0 , 4 )]
312+ ).flatten ()
313+ else :
314+ # Original values
315+ expected_persisted = arr
316+ persisted_data = streaming_node .read ()
317+ np .testing .assert_array_equal (persisted_data , expected_persisted )
318+
319+
320+ def test_updates_append_without_persist (tiled_websocket_context ):
321+ context = tiled_websocket_context
322+ client = from_context (context )
323+
324+ # Create streaming array node using Tiled client
325+ arr = np .arange (10 )
326+ streaming_node = client .write_array (arr , key = "test_stream_immediate" )
327+
328+ with pytest .raises (ValueError , match = "Cannot PATCH an array with both parameters" ):
329+ # Extending an array with persist=False is not yet supported
330+ send_ws_updates (streaming_node , append_array , count = 1 , persist = False )
331+
332+
201333def test_close_stream_success (tiled_websocket_context ):
202334 """Test successful close of an existing stream."""
203335 context = tiled_websocket_context
0 commit comments