Skip to content

Commit ccc9332

Browse files
authored
Merge pull request #866 from genematx/fix-recursion
Fix Recursion Errors in BlueskyRun v2 Client
2 parents 1b482f1 + fa1b292 commit ccc9332

File tree

1 file changed

+53
-40
lines changed

1 file changed

+53
-40
lines changed

bluesky-tiled-plugins/bluesky_tiled_plugins/bluesky_run.py

Lines changed: 53 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from datetime import datetime
88
from typing import Optional
99

10-
from tiled.client.composite import CompositeClient
1110
from tiled.client.container import Container
1211
from tiled.client.utils import handle_error
1312

@@ -131,6 +130,18 @@ def read(self):
131130
"Reading any entire run is not supported. Access a stream in this run and read that."
132131
)
133132

133+
@property
134+
def base(self):
135+
"Return the base Container client instead of a BlueskyRun client"
136+
return Container(
137+
self.context,
138+
item=self.item,
139+
structure_clients=self.structure_clients,
140+
queries=self._queries,
141+
sorting=self._sorting,
142+
include_data_sources=self._include_data_sources,
143+
)
144+
134145
to_dask = read
135146

136147

@@ -215,8 +226,29 @@ class _BlueskyRunSQL(BlueskyRun):
215226
base class for other classes (v2 and v3) that implement additional methods.
216227
"""
217228

218-
def __init__(self, *args, **kwargs):
219-
super().__init__(*args, **kwargs)
229+
@functools.cached_property
230+
def _has_streams_namespace(self) -> bool:
231+
"""Determine whether the BlueskyRun has an intermediate "streams" namespace.
232+
233+
Maintained for backward compatibility. Returns True if the following conditions are met:
234+
1. There is a "streams" key in the base container.
235+
2. The specs of the "streams" container do not include "BlueskyEventStream",
236+
indicating that "streams" is not itself a BlueskyEventStream.
237+
"""
238+
return ("streams" in self.base) and (
239+
"BlueskyEventStream" not in {s.name for s in self.base["streams"].specs}
240+
)
241+
242+
@functools.cached_property
243+
def _stream_names(self) -> list[str]:
244+
"""Get the sorted list of stream names in the BlueskyRun.
245+
246+
This property accounts for both the new layout (without "streams" namespace)
247+
and the old layout (with "streams" namespace), in which case the stream names
248+
are derived from the keys under the "streams" namespace.
249+
"""
250+
251+
return sorted(k for k in (self.base["streams"] if self._has_streams_namespace else self.base))
220252

221253
def __getitem__(self, key):
222254
if isinstance(key, tuple):
@@ -281,18 +313,18 @@ def _base_getitem(key):
281313
else:
282314
raise KeyError from e
283315

284-
@functools.cached_property
285-
def _has_streams_namespace(self):
286-
return ("streams" in self) and ("BlueskyEventStream" not in {s.name for s in self["streams"].specs})
316+
def _keys_slice(self, start, stop, direction, page_size: Optional[int] = None, **kwargs):
317+
sorted_keys = reversed(self._stream_names) if direction < 0 else self._stream_names
318+
return (yield from sorted_keys[start:stop])
287319

288-
@functools.cached_property
289-
def _stream_names(self):
290-
# Access to the "streams" namespace (possibly a separate container)
291-
if self._has_streams_namespace:
292-
return sorted(k for k in self["streams"])
293-
else:
294-
# No intermediate "streams" node, use the top-level node
295-
return sorted(k for k in self)
320+
def _items_slice(self, start, stop, direction, page_size: Optional[int] = None, **kwargs):
321+
sorted_keys = reversed(self._stream_names) if direction < 0 else self._stream_names
322+
for key in sorted_keys[start:stop]:
323+
yield key, self[key]
324+
return
325+
326+
def __iter__(self):
327+
yield from self._stream_names
296328

297329
def documents(self, fill=False):
298330
with io.BytesIO() as buffer:
@@ -304,43 +336,24 @@ def documents(self, fill=False):
304336

305337

306338
class BlueskyRunV2SQL(BlueskyRunV2, _BlueskyRunSQL):
307-
def _keys_slice(self, start, stop, direction, page_size: Optional[int] = None, **kwargs):
308-
if self._has_streams_namespace:
309-
keys = reversed(self._stream_names) if direction < 0 else self._stream_names
310-
return (yield from keys[start:stop])
311-
else:
312-
return (yield from super()._keys_slice(start, stop, direction, page_size=page_size, **kwargs))
313-
314-
def _items_slice(self, start, stop, direction, page_size: Optional[int] = None, **kwargs):
315-
if self._has_streams_namespace:
316-
_streams_node = super().get("streams", {})
317-
for key in reversed(self._stream_names) if direction < 0 else self._stream_names:
318-
yield key, _streams_node.get(key)
319-
return
320-
else:
321-
return (yield from super()._items_slice(start, stop, direction, page_size=page_size, **kwargs))
322-
323339
def __getitem__(self, key):
324-
# For v3, we need to handle the streams and configs keys specially
325-
if key == "streams":
326-
return super().__getitem__("streams")
327-
340+
# For v2, we need to handle the streams and configs keys specially
328341
if isinstance(key, tuple):
329342
key = "/".join(key)
330343

331344
key, *rest = key.split("/", 1)
332-
stream_composite_client = super().__getitem__(key)
333-
if not isinstance(stream_composite_client, CompositeClient):
345+
346+
if key == "streams":
334347
raise KeyError(
335348
"Looks like you are trying to access the 'streams' namespace, "
336349
"but this pathway has never been supported in the .v2 BlueskyRun client. "
337350
"Please access the stream directly, e.g. run['primary']."
338351
)
352+
353+
stream_composite_client = super().__getitem__(key)
339354
stream_container = BlueskyEventStreamV2SQL.from_stream_client(stream_composite_client)
340-
return stream_container[rest] if rest else stream_container
341355

342-
def __iter__(self):
343-
yield from self._stream_names
356+
return stream_container[rest[0]] if rest else stream_container
344357

345358

346359
class BlueskyRunV3(_BlueskyRunSQL):
@@ -356,8 +369,8 @@ def __new__(cls, context, *, item, structure_clients, **kwargs):
356369
return BlueskyRunV2Mongo(context, item=item, structure_clients=structure_clients, **kwargs)
357370

358371
def __getattr__(self, key):
372+
# A shortcut to the stream data
359373
if key in self._stream_names:
360-
# A shortcut to the stream data
361374
return self["streams"][key] if self._has_streams_namespace else self[key]
362375

363376
return super().__getattr__(key)

0 commit comments

Comments
 (0)