Commit a5d0040

Browse files
authored
Merge pull request #872 from genematx/internal-arrays
Internal Arrays as zarr
2 parents f8ca1d0 + ef9738e commit a5d0040

5 files changed: +225 -32 lines changed


bluesky-tiled-plugins/bluesky_tiled_plugins/consolidators.py

Lines changed: 31 additions & 0 deletions
@@ -18,6 +18,37 @@ class Patch:
     shape: tuple[int, ...]
     offset: tuple[int, ...]

+    @classmethod
+    def combine_patches(cls, patches: list["Patch"]) -> "Patch":
+        """Combine multiple patches into a single patch
+
+        The combined patch covers the union (smallest bounding box) of all provided patches.
+
+        Parameters
+        ----------
+        patches : list[Patch]
+            A list of Patch objects to combine.
+
+        Returns
+        -------
+        Patch
+            A new Patch object that covers the union of all input patches.
+        """
+
+        # Determine the overall shape and offset
+        min_offset = list(patches[0].offset)
+        max_extent = [offset + size for offset, size in zip(patches[0].offset, patches[0].shape)]
+
+        for patch in patches[1:]:
+            for i in range(len(min_offset)):
+                min_offset[i] = min(min_offset[i], patch.offset[i])
+                max_extent[i] = max(max_extent[i], patch.offset[i] + patch.shape[i])
+
+        combined_shape = tuple(max_e - min_o for min_o, max_e in zip(min_offset, max_extent))
+        combined_offset = tuple(min_offset)
+
+        return cls(shape=combined_shape, offset=combined_offset)
+

 class ConsolidatorBase:
     """Consolidator of StreamDatums

bluesky-tiled-plugins/bluesky_tiled_plugins/tiled_writer.py

Lines changed: 78 additions & 19 deletions
@@ -6,6 +6,7 @@
 from typing import Any, Callable, Optional, Union, cast
 from warnings import warn

+import numpy
 import pyarrow
 from bluesky.callbacks.core import CallbackBase
 from bluesky.callbacks.json_writer import JSONLinesWriter
@@ -34,6 +35,7 @@
 from event_model.documents.event_descriptor import DataKey
 from event_model.documents.stream_datum import StreamRange
 from tiled.client import from_profile, from_uri
+from tiled.client.array import ArrayClient
 from tiled.client.base import BaseClient
 from tiled.client.container import Container
 from tiled.client.dataframe import DataFrameClient
@@ -46,6 +48,10 @@
 # Aggregate the Event table rows and StreamDatums in batches before writing to Tiled
 BATCH_SIZE = 10000

+# Maximum size of internal arrays from Event docs to write to tabular (SQL) storage; larger arrays will be written
+# as zarr. Set to 0 to write all internal arrays as zarr, and -1 to write all internal arrays to tabular storage.
+MAX_ARRAY_SIZE = 16
+
 # Disallow using reserved words as data_keys identifiers
 # Related: https://github.com/bluesky/event-model/pull/223
 RESERVED_DATA_KEYS = ["time", "seq_num"]
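
To make the MAX_ARRAY_SIZE semantics concrete, here is a small sketch of the routing rule that descriptor() applies further down in this diff. The helper name is hypothetical, and note that the threshold is compared against the sum of the declared shape:

MAX_ARRAY_SIZE = 16  # same default as above


def routed_to_zarr(shape: list[int], max_array_size: int = MAX_ARRAY_SIZE) -> bool:
    # Mirrors the check `0 <= self._max_array_size < sum(val.get("shape", []))` in descriptor()
    return 0 <= max_array_size < sum(shape)


assert routed_to_zarr([8], max_array_size=16) is False     # small array stays in the SQL-backed table
assert routed_to_zarr([100], max_array_size=16) is True    # large array is written as zarr
assert routed_to_zarr([8], max_array_size=0) is True       # 0: every internal array goes to zarr
assert routed_to_zarr([1000], max_array_size=-1) is False  # -1: everything stays tabular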
@@ -523,25 +529,44 @@ class _RunWriter(CallbackBase):
         The Tiled client to use for writing the data.
     """

-    def __init__(self, client: BaseClient, batch_size: int = BATCH_SIZE):
+    def __init__(self, client: BaseClient, batch_size: int = BATCH_SIZE, max_array_size: int = MAX_ARRAY_SIZE):
         self.client = client
         self.root_node: Union[None, Container] = None
         self._desc_nodes: dict[str, Container] = {}  # references to the descriptor nodes by their uid's and names
         self._sres_nodes: dict[str, BaseClient] = {}
         self._internal_tables: dict[str, DataFrameClient] = {}  # references to the internal tables by desc_names
+        self._internal_arrays: dict[str, ArrayClient] = {}  # refs to the internal arrays by desc_name/data_key
         self._stream_resource_cache: dict[str, StreamResource] = {}
         self._consolidators: dict[str, ConsolidatorBase] = {}
         self._internal_data_cache: dict[str, list[dict[str, Any]]] = defaultdict(list)
         self._external_data_cache: dict[str, StreamDatum] = {}  # sres_uid : (concatenated) StreamDatum
-        self._batch_size = batch_size
+        self._int_array_keys: dict[str, set[str]] = defaultdict(set)  # data_keys with array data by desc_name
+        self._batch_size: int = batch_size
+        self._max_array_size: int = max_array_size  # Max size of arrays to write to tabular storage
         self.data_keys: dict[str, DataKey] = {}
-        self.access_tags = None
+        self.access_tags: Optional[list[str]] = None

     def _write_internal_data(self, data_cache: list[dict[str, Any]], desc_node: Container):
         """Write the internal data table to Tiled and clear the cache."""

         desc_name = desc_node.item["id"]  # Name of the descriptor (stream)
-        table = pyarrow.Table.from_pylist(data_cache)
+        # 1. Write internal array data, if any; remove it from the tabular data
+        for key in self._int_array_keys[desc_name]:
+            array = numpy.array([row.pop(key) for row in data_cache if key in row])
+            if not (arr_client := self._internal_arrays.get(f"{desc_name}/{key}")):
+                # Create a new "internal" array data node and write the initial piece of data
+                metadata = truncate_json_overflow(self.data_keys.get(key, {}))
+                dims = ("time",) + tuple(f"dim_{i}" for i in range(1, array.ndim))
+                arr_client = desc_node.write_array(
+                    array, key=key, metadata=metadata, dims=dims, access_tags=self.access_tags
+                )
+                self._internal_arrays[f"{desc_name}/{key}"] = arr_client
+            else:
+                arr_client.patch(array, offset=arr_client.shape[:1], extend=True)
+
+        # 2. Write internal tabular data; all data_keys for arrays have been removed from data_cache on step 1
+        if not (table := pyarrow.Table.from_pylist(data_cache)):
+            return  # Nothing to write

         if not (df_client := self._internal_tables.get(desc_name)):
             # Create a new "internal" data node and write the initial piece of data
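
As an aside, the split performed above can be illustrated without a Tiled server: array-valued columns are popped out of the cached Event rows and stacked along a leading time axis, and whatever remains goes into the pyarrow table. The row contents below are made up for illustration:

import numpy
import pyarrow

data_cache = [
    {"seq_num": 1, "det": 1.0, "long": [0, 1, 2, 3, 4, 5, 6, 7]},
    {"seq_num": 2, "det": 1.0, "long": [10, 11, 12, 13, 14, 15, 16, 17]},
]
int_array_keys = {"long"}  # data_keys routed to zarr-backed array nodes

# Step 1: pop the array columns and stack them along a leading "time" axis
arrays = {key: numpy.array([row.pop(key) for row in data_cache if key in row]) for key in int_array_keys}
assert arrays["long"].shape == (2, 8)

# Step 2: the remaining scalar columns become the tabular (SQL-backed) part
table = pyarrow.Table.from_pylist(data_cache)
print(table.column_names)  # expected: ['seq_num', 'det']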
@@ -562,18 +587,18 @@ def _write_internal_data(self, data_cache: list[dict[str, Any]], desc_node: Cont

         df_client.append_partition(0, table)

-    def _write_external_data(self, doc: StreamDatum):
-        """Register the external data provided in StreamDatum in Tiled"""
+    def _update_consolidator(self, doc: StreamDatum):
+        """Register the external data from StreamDatum in the Consolidator"""

         sres_uid, desc_uid = doc["stream_resource"], doc["descriptor"]
         sres_node, consolidator = self.get_sres_node(sres_uid, desc_uid)
         patch = consolidator.consume_stream_datum(doc)
-        self._update_data_source_for_node(sres_node, consolidator.get_data_source(), patch)
+        return sres_node, consolidator, patch

     def _update_data_source_for_node(
         self, node: BaseClient, data_source: DataSource, patch: Optional[Patch] = None
     ):
-        """Update StreamResource node in Tiled"""
+        """Update DataSource of the node in Tiled corresponding to the StreamResource"""
         data_source.id = node.data_sources()[0].id  # ID of the existing DataSource record
         handle_error(
             node.context.http_client.put(
@@ -588,6 +613,12 @@
             )
         ).json()

+    def _write_external_data(self, doc: StreamDatum):
+        """Register (or update) the external data from StreamDatum in Tiled"""
+
+        sres_node, consolidator, patch = self._update_consolidator(doc)
+        self._update_data_source_for_node(sres_node, consolidator.get_data_source(), patch)
+
     def start(self, doc: RunStart):
         doc = copy.copy(doc)
         self.access_tags = doc.pop("tiled_access_tags", None)  # type: ignore
@@ -608,23 +639,40 @@ def stop(self, doc: RunStop):
             self._write_internal_data(data_cache, desc_node=self._desc_nodes[desc_name])
             data_cache.clear()

-        # Write the cached StreamDatums data
+        # Write the cached StreamDatums data.
+        # Only update the data_source _once_ per each StreamResource node, even if consuming multiple StreamDatums.
+        updated_node_and_cons: dict[tuple[BaseClient, ConsolidatorBase], list[Patch]] = defaultdict(list)
         for stream_datum_doc in self._external_data_cache.values():
-            self._write_external_data(stream_datum_doc)
-
-        # Validate structure for some StreamResource nodes
-        for sres_uid, sres_node in self._sres_nodes.items():
-            consolidator = self._consolidators[sres_uid]
+            sres_node, consolidator, patch = self._update_consolidator(stream_datum_doc)
+            updated_node_and_cons[(sres_node, consolidator)].append(patch)
+        for (sres_node, consolidator), patches in updated_node_and_cons.items():
+            final_patch = Patch.combine_patches(patches)
+            self._update_data_source_for_node(sres_node, consolidator.get_data_source(), patch=final_patch)
+
+        # Validate structure for some StreamResource nodes, select unique pairs of (sres_node, consolidator)
+        notes = []
+        node_and_cons = {
+            (sres_node, self._consolidators[sres_uid]) for sres_uid, sres_node in self._sres_nodes.items()
+        }
+        for sres_node, consolidator in node_and_cons:
             if consolidator._sres_parameters.get("_validate", False):
+                title = f"Validation of data key '{sres_node.item['id']}'"
                 try:
-                    consolidator.validate(fix_errors=True)
+                    _notes = consolidator.validate(fix_errors=True)
+                    notes.extend([title + ": " + note for note in _notes])
                 except Exception as e:
                     msg = f"{type(e).__name__}: " + str(e).replace("\n", " ").replace("\r", "").strip()
-                    warn(f"Validation of StreamResource {sres_uid} failed with error: {msg}", stacklevel=2)
+                    msg = title + f" failed with error: {msg}"
+                    warn(msg, stacklevel=2)
+                    notes.append(msg)
                 self._update_data_source_for_node(sres_node, consolidator.get_data_source())

         # Write the stop document to the metadata
-        self.root_node.update_metadata(metadata={"stop": doc, **dict(self.root_node.metadata)}, drop_revision=True)
+        for key in self._internal_arrays.keys():
+            notes.append(f"Internal array data in '{key}' written as zarr format.")
+        notes = doc.pop("_run_normalizer_notes", []) + notes  # Retrieve notes from the normalizer, if any
+        md_update = {"stop": doc, **({"notes": notes} if notes else {})}
+        self.root_node.update_metadata(metadata=md_update, drop_revision=True)

     def descriptor(self, doc: EventDescriptor):
         desc_name = doc["name"]  # Name of the descriptor/stream
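
The batching in stop() above exists to issue one DataSource update per StreamResource node rather than one per StreamDatum. A rough standalone sketch of that pattern, using plain strings and tuples in place of the Tiled node, consolidator, and Patch objects:

from collections import defaultdict

# Patches recorded as (offset, shape) along the leading "time" axis; the keys stand in for
# (sres_node, consolidator) pairs and are purely illustrative.
pending: dict[str, list[tuple[tuple[int, ...], tuple[int, ...]]]] = defaultdict(list)
pending["primary/det_image"].append(((0,), (5,)))   # StreamDatum 1: frames 0-4
pending["primary/det_image"].append(((5,), (5,)))   # StreamDatum 2: frames 5-9

for node, patches in pending.items():
    start = min(off[0] for off, _ in patches)
    stop = max(off[0] + shp[0] for off, shp in patches)
    # One combined update per node, covering the bounding box of all its patches
    print(node, "offset", (start,), "shape", (stop - start,))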
@@ -641,6 +689,15 @@ def descriptor(self, doc: EventDescriptor):
                 specs=[Spec("BlueskyEventStream", version="3.0"), Spec("composite")],
                 access_tags=self.access_tags,
             ).base
+
+            # Keep track of data_keys for internal array data to be written as zarr, if any
+            for key, val in doc.get("data_keys", {}).items():
+                if (
+                    ("external" not in val.keys())
+                    and (val.get("dtype") == "array")
+                    and (0 <= self._max_array_size < sum(val.get("shape", [])))
+                ):
+                    self._int_array_keys[desc_name].add(key)
         else:
             # Rare Case: This new descriptor likely updates stream configs mid-experiment
             # We assume that the full descriptor has been already received, so we don't need to store everything
@@ -693,7 +750,7 @@ def get_sres_node(self, sres_uid: str, desc_uid: Optional[str] = None) -> tuple[

         elif sres_uid in self._stream_resource_cache.keys():
             if not desc_uid:
-                raise RuntimeError("Descriptor uid must be specified to initialise a Stream Resource node")
+                raise RuntimeError("Descriptor uid must be specified to initialize a Stream Resource node")

             # Define `full_data_key` as desc_name + _ + data_key to ensure uniqueness across streams
             sres_doc = self._stream_resource_cache[sres_uid]
@@ -796,6 +853,7 @@ def __init__(
         spec_to_mimetype: Optional[dict[str, str]] = None,
         backup_directory: Optional[str] = None,
         batch_size: int = BATCH_SIZE,
+        max_array_size: int = MAX_ARRAY_SIZE,
     ):
         self.client = client.include_data_sources()
         self.patches = patches or {}
@@ -804,10 +862,11 @@ def __init__(
         self._normalizer = normalizer
         self._run_router = RunRouter([self._factory])
         self._batch_size = batch_size
+        self._max_array_size = max_array_size

     def _factory(self, name, doc):
         """Factory method to create a callback for writing a single run into Tiled."""
-        cb = run_writer = _RunWriter(self.client, batch_size=self._batch_size)
+        cb = run_writer = _RunWriter(self.client, batch_size=self._batch_size, max_array_size=self._max_array_size)

         if self._normalizer:
             # If normalize is True, create a RunNormalizer callback to update documents to the latest schema
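
Presumably the new knob is used along these lines (a sketch, assuming the writer class defined in this module is TiledWriter, importable from bluesky_tiled_plugins.tiled_writer, and that a Tiled server is reachable at the given URL):

from bluesky import RunEngine
from tiled.client import from_uri

from bluesky_tiled_plugins.tiled_writer import TiledWriter  # assumed import path

client = from_uri("http://localhost:8000", api_key="secret")  # hypothetical local Tiled server

# Internal arrays from Event documents whose summed shape exceeds max_array_size are
# written as zarr; smaller ones keep going to the tabular (SQL) storage.
tw = TiledWriter(client, batch_size=10_000, max_array_size=16)

RE = RunEngine()
RE.subscribe(tw)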

bluesky-tiled-plugins/tests/examples/internal_events.json

Lines changed: 32 additions & 10 deletions
@@ -85,13 +85,20 @@
             "dtype": "array",
             "shape": [],
             "object_name": "det"
+        },
+        "long": {
+            "source": "SIM:det",
+            "dtype": "array",
+            "shape": [8],
+            "object_name": "det"
         }
     },
     "name": "primary",
     "object_keys": {
         "det": [
             "det",
-            "empty"
+            "empty",
+            "long"
         ]
     },
     "run_start": "{{ uuid }}-dc0518f354ee",
@@ -101,7 +108,8 @@
     "det": {
         "fields": [
             "det",
-            "empty"
+            "empty",
+            "long"
         ]
     }
 }
@@ -114,11 +122,13 @@
     "time": 1745338531.215753,
     "data": {
         "det": 1.0,
-        "empty": []
+        "empty": [],
+        "long": [0, 1, 2, 3, 4, 5, 6, 7]
     },
     "timestamps": {
         "det": 1745338531.001646,
-        "empty": 1745338531.001647
+        "empty": 1745338531.001647,
+        "long": 1745338531.001648
     },
     "seq_num": 1,
     "filled": {},
@@ -132,11 +142,13 @@
     "time": 1745338531.218661,
     "data": {
         "det": 1.0,
-        "empty": []
+        "empty": [],
+        "long": [10, 11, 12, 13, 14, 15, 16, 17]
     },
     "timestamps": {
         "det": 1745338531.217144,
-        "empty": 1745338531.217145
+        "empty": 1745338531.217145,
+        "long": 1745338531.217146
     },
     "seq_num": 2,
     "filled": {},
@@ -214,13 +226,20 @@
             "dtype": "array",
             "shape": [],
             "object_name": "det"
+        },
+        "long": {
+            "source": "SIM:det",
+            "dtype": "array",
+            "shape": [8],
+            "object_name": "det"
         }
     },
     "name": "primary",
     "object_keys": {
         "det": [
             "det",
-            "empty"
+            "empty",
+            "long"
         ]
     },
     "run_start": "{{ uuid }}-dc0518f354ee",
@@ -230,7 +249,8 @@
     "det": {
         "fields": [
             "det",
-            "empty"
+            "empty",
+            "long"
         ]
     }
 }
@@ -243,11 +263,13 @@
     "time": 1745338533.2213702,
     "data": {
         "det": 1.0,
-        "empty": []
+        "empty": [],
+        "long": [20, 21, 22, 23, 24, 25, 26, 27]
     },
     "timestamps": {
         "det": 1745338533.219767,
-        "empty": 1745338533.219768
+        "empty": 1745338533.219768,
+        "long": 1745338533.219769
     },
     "seq_num": 3,
     "filled": {},