 from typing import Any, Callable, Optional, Union, cast
 from warnings import warn

+import numpy
 import pyarrow
 from bluesky.callbacks.core import CallbackBase
 from bluesky.callbacks.json_writer import JSONLinesWriter
 from event_model.documents.event_descriptor import DataKey
 from event_model.documents.stream_datum import StreamRange
 from tiled.client import from_profile, from_uri
+from tiled.client.array import ArrayClient
 from tiled.client.base import BaseClient
 from tiled.client.container import Container
 from tiled.client.dataframe import DataFrameClient
 # Aggregate the Event table rows and StreamDatums in batches before writing to Tiled
 BATCH_SIZE = 10000

+# Maximum size of internal arrays from Event docs to write to tabular (SQL) storage; larger arrays will be written
+# as zarr. Set to 0 to write all internal arrays as zarr, and -1 to write all internal arrays to tabular storage.
+MAX_ARRAY_SIZE = 16
+
 # Disallow using reserved words as data_keys identifiers
 # Related: https://github.com/bluesky/event-model/pull/223
 RESERVED_DATA_KEYS = ["time", "seq_num"]
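For clarity, a minimal sketch of how this threshold is applied; the helper name goes_to_zarr is illustrative and does not exist in the code, but the comparison mirrors the check added to descriptor() further down in this diff:

    def goes_to_zarr(shape: list[int], max_array_size: int = 16) -> bool:
        # 16 is the MAX_ARRAY_SIZE default above; an internal array is routed to zarr
        # when its declared shape sums above the threshold
        return 0 <= max_array_size < sum(shape)

    goes_to_zarr([4, 4])                     # False -- sum(shape) == 8 <= 16, stays in tabular (SQL) storage
    goes_to_zarr([1024])                     # True  -- above the threshold, written as zarr
    goes_to_zarr([4, 4], max_array_size=0)   # True  -- 0 sends every internal array to zarr
    goes_to_zarr([1024], max_array_size=-1)  # False -- -1 keeps every internal array in tabular storage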
@@ -523,25 +529,44 @@ class _RunWriter(CallbackBase):
         The Tiled client to use for writing the data.
     """

-    def __init__(self, client: BaseClient, batch_size: int = BATCH_SIZE):
+    def __init__(self, client: BaseClient, batch_size: int = BATCH_SIZE, max_array_size: int = MAX_ARRAY_SIZE):
         self.client = client
         self.root_node: Union[None, Container] = None
         self._desc_nodes: dict[str, Container] = {}  # references to the descriptor nodes by their uid's and names
         self._sres_nodes: dict[str, BaseClient] = {}
         self._internal_tables: dict[str, DataFrameClient] = {}  # references to the internal tables by desc_names
+        self._internal_arrays: dict[str, ArrayClient] = {}  # refs to the internal arrays by desc_name/data_key
         self._stream_resource_cache: dict[str, StreamResource] = {}
         self._consolidators: dict[str, ConsolidatorBase] = {}
         self._internal_data_cache: dict[str, list[dict[str, Any]]] = defaultdict(list)
         self._external_data_cache: dict[str, StreamDatum] = {}  # sres_uid : (concatenated) StreamDatum
-        self._batch_size = batch_size
+        self._int_array_keys: dict[str, set[str]] = defaultdict(set)  # data_keys with array data by desc_name
+        self._batch_size: int = batch_size
+        self._max_array_size: int = max_array_size  # Max size of arrays to write to tabular storage
         self.data_keys: dict[str, DataKey] = {}
-        self.access_tags = None
+        self.access_tags: Optional[list[str]] = None

     def _write_internal_data(self, data_cache: list[dict[str, Any]], desc_node: Container):
         """Write the internal data table to Tiled and clear the cache."""

         desc_name = desc_node.item["id"]  # Name of the descriptor (stream)
-        table = pyarrow.Table.from_pylist(data_cache)
+        # 1. Write internal array data, if any; remove it from the tabular data
+        for key in self._int_array_keys[desc_name]:
+            array = numpy.array([row.pop(key) for row in data_cache if key in row])
+            if not (arr_client := self._internal_arrays.get(f"{desc_name}/{key}")):
+                # Create a new "internal" array data node and write the initial piece of data
+                metadata = truncate_json_overflow(self.data_keys.get(key, {}))
+                dims = ("time",) + tuple(f"dim_{i}" for i in range(1, array.ndim))
+                arr_client = desc_node.write_array(
+                    array, key=key, metadata=metadata, dims=dims, access_tags=self.access_tags
+                )
+                self._internal_arrays[f"{desc_name}/{key}"] = arr_client
+            else:
+                arr_client.patch(array, offset=arr_client.shape[:1], extend=True)
+
+        # 2. Write internal tabular data; all data_keys for arrays have been removed from data_cache in step 1
+        if not (table := pyarrow.Table.from_pylist(data_cache)):
+            return  # Nothing to write

         if not (df_client := self._internal_tables.get(desc_name)):
             # Create a new "internal" data node and write the initial piece of data
@@ -562,18 +587,18 @@ def _write_internal_data(self, data_cache: list[dict[str, Any]], desc_node: Cont

         df_client.append_partition(0, table)

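As a worked illustration of the split performed in _write_internal_data above (the data_key "det_roi", the column "temp", and the row values are hypothetical): per-event array values are popped out of the cached rows and stacked along a leading "time" dimension for the zarr-backed array node, while only the remaining scalar columns end up in the pyarrow table:

    import numpy
    import pyarrow

    data_cache = [
        {"time": 1.0, "temp": 20.1, "det_roi": [[1, 2], [3, 4]]},
        {"time": 2.0, "temp": 20.3, "det_roi": [[5, 6], [7, 8]]},
    ]
    array = numpy.array([row.pop("det_roi") for row in data_cache if "det_roi" in row])  # shape (2, 2, 2)
    dims = ("time",) + tuple(f"dim_{i}" for i in range(1, array.ndim))                   # ("time", "dim_1", "dim_2")
    table = pyarrow.Table.from_pylist(data_cache)                                        # columns: time, temp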
-    def _write_external_data(self, doc: StreamDatum):
-        """Register the external data provided in StreamDatum in Tiled"""
+    def _update_consolidator(self, doc: StreamDatum):
+        """Register the external data from StreamDatum in the Consolidator"""

         sres_uid, desc_uid = doc["stream_resource"], doc["descriptor"]
         sres_node, consolidator = self.get_sres_node(sres_uid, desc_uid)
         patch = consolidator.consume_stream_datum(doc)
-        self._update_data_source_for_node(sres_node, consolidator.get_data_source(), patch)
+        return sres_node, consolidator, patch

     def _update_data_source_for_node(
         self, node: BaseClient, data_source: DataSource, patch: Optional[Patch] = None
     ):
-        """Update StreamResource node in Tiled"""
+        """Update the DataSource of the Tiled node corresponding to the StreamResource"""
         data_source.id = node.data_sources()[0].id  # ID of the existing DataSource record
         handle_error(
             node.context.http_client.put(
@@ -588,6 +613,12 @@ def _update_data_source_for_node(
             )
         ).json()

+    def _write_external_data(self, doc: StreamDatum):
+        """Register (or update) the external data from StreamDatum in Tiled"""
+
+        sres_node, consolidator, patch = self._update_consolidator(doc)
+        self._update_data_source_for_node(sres_node, consolidator.get_data_source(), patch)
+
     def start(self, doc: RunStart):
         doc = copy.copy(doc)
         self.access_tags = doc.pop("tiled_access_tags", None)  # type: ignore
@@ -608,23 +639,40 @@ def stop(self, doc: RunStop):
                 self._write_internal_data(data_cache, desc_node=self._desc_nodes[desc_name])
                 data_cache.clear()

-        # Write the cached StreamDatums data
+        # Write the cached StreamDatums data.
+        # Only update the data_source _once_ per StreamResource node, even if consuming multiple StreamDatums.
+        updated_node_and_cons: dict[tuple[BaseClient, ConsolidatorBase], list[Patch]] = defaultdict(list)
         for stream_datum_doc in self._external_data_cache.values():
-            self._write_external_data(stream_datum_doc)
-
-        # Validate structure for some StreamResource nodes
-        for sres_uid, sres_node in self._sres_nodes.items():
-            consolidator = self._consolidators[sres_uid]
+            sres_node, consolidator, patch = self._update_consolidator(stream_datum_doc)
+            updated_node_and_cons[(sres_node, consolidator)].append(patch)
+        for (sres_node, consolidator), patches in updated_node_and_cons.items():
+            final_patch = Patch.combine_patches(patches)
+            self._update_data_source_for_node(sres_node, consolidator.get_data_source(), patch=final_patch)
+
+        # Validate structure for some StreamResource nodes, select unique pairs of (sres_node, consolidator)
+        notes = []
+        node_and_cons = {
+            (sres_node, self._consolidators[sres_uid]) for sres_uid, sres_node in self._sres_nodes.items()
+        }
+        for sres_node, consolidator in node_and_cons:
             if consolidator._sres_parameters.get("_validate", False):
+                title = f"Validation of data key '{sres_node.item['id']}'"
                 try:
-                    consolidator.validate(fix_errors=True)
+                    _notes = consolidator.validate(fix_errors=True)
+                    notes.extend([title + ": " + note for note in _notes])
                 except Exception as e:
                     msg = f"{type(e).__name__}: " + str(e).replace("\n", " ").replace("\r", "").strip()
-                    warn(f"Validation of StreamResource {sres_uid} failed with error: {msg}", stacklevel=2)
+                    msg = title + f" failed with error: {msg}"
+                    warn(msg, stacklevel=2)
+                    notes.append(msg)
             self._update_data_source_for_node(sres_node, consolidator.get_data_source())

         # Write the stop document to the metadata
-        self.root_node.update_metadata(metadata={"stop": doc, **dict(self.root_node.metadata)}, drop_revision=True)
+        for key in self._internal_arrays.keys():
+            notes.append(f"Internal array data in '{key}' written as zarr format.")
+        notes = doc.pop("_run_normalizer_notes", []) + notes  # Retrieve notes from the normalizer, if any
+        md_update = {"stop": doc, **({"notes": notes} if notes else {})}
+        self.root_node.update_metadata(metadata=md_update, drop_revision=True)

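For illustration, the metadata written at the end of a run by the stop() hook above is shaped roughly as follows; stop_doc stands for the RunStop document received by stop(), and the note strings are hypothetical examples of what the validators and the zarr bookkeeping may produce:

    md_update = {
        "stop": stop_doc,  # the RunStop document (name assumed for this sketch)
        "notes": [
            "Validation of data key 'det_image' failed with error: ValueError: ...",        # hypothetical
            "Internal array data in 'primary/det_roi' written as zarr format.",             # hypothetical
        ],
    }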
     def descriptor(self, doc: EventDescriptor):
         desc_name = doc["name"]  # Name of the descriptor/stream
@@ -641,6 +689,15 @@ def descriptor(self, doc: EventDescriptor):
                 specs=[Spec("BlueskyEventStream", version="3.0"), Spec("composite")],
                 access_tags=self.access_tags,
             ).base
+
+            # Keep track of data_keys for internal array data to be written as zarr, if any
+            for key, val in doc.get("data_keys", {}).items():
+                if (
+                    ("external" not in val.keys())
+                    and (val.get("dtype") == "array")
+                    and (0 <= self._max_array_size < sum(val.get("shape", [])))
+                ):
+                    self._int_array_keys[desc_name].add(key)
         else:
             # Rare Case: This new descriptor likely updates stream configs mid-experiment
             # We assume that the full descriptor has already been received, so we don't need to store everything
@@ -693,7 +750,7 @@ def get_sres_node(self, sres_uid: str, desc_uid: Optional[str] = None) -> tuple[

         elif sres_uid in self._stream_resource_cache.keys():
             if not desc_uid:
-                raise RuntimeError("Descriptor uid must be specified to initialise a Stream Resource node")
+                raise RuntimeError("Descriptor uid must be specified to initialize a Stream Resource node")

             # Define `full_data_key` as desc_name + _ + data_key to ensure uniqueness across streams
             sres_doc = self._stream_resource_cache[sres_uid]
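For reference, the uniqueness convention mentioned in the comment above amounts to the following one-liner (the example values are illustrative):

    full_data_key = f"{desc_name}_{data_key}"  # e.g. "primary" + "_" + "det_image" -> "primary_det_image"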
@@ -796,6 +853,7 @@ def __init__(
         spec_to_mimetype: Optional[dict[str, str]] = None,
         backup_directory: Optional[str] = None,
         batch_size: int = BATCH_SIZE,
+        max_array_size: int = MAX_ARRAY_SIZE,
     ):
         self.client = client.include_data_sources()
         self.patches = patches or {}
@@ -804,10 +862,11 @@ def __init__(
         self._normalizer = normalizer
         self._run_router = RunRouter([self._factory])
         self._batch_size = batch_size
+        self._max_array_size = max_array_size

     def _factory(self, name, doc):
         """Factory method to create a callback for writing a single run into Tiled."""
-        cb = run_writer = _RunWriter(self.client, batch_size=self._batch_size)
+        cb = run_writer = _RunWriter(self.client, batch_size=self._batch_size, max_array_size=self._max_array_size)

         if self._normalizer:
             # If normalize is True, create a RunNormalizer callback to update documents to the latest schema
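A minimal usage sketch, assuming the enclosing class in this diff is the TiledWriter callback and that a Tiled server is reachable at the given URL (both assumptions, not shown in the diff itself):

    from bluesky import RunEngine
    from tiled.client import from_uri

    client = from_uri("http://localhost:8000")                      # hypothetical Tiled server
    tw = TiledWriter(client, batch_size=10_000, max_array_size=16)  # TiledWriter: the writer class assumed above;
                                                                    # internal arrays with sum(shape) > 16 go to zarr
    RE = RunEngine()
    RE.subscribe(tw)                                                # documents from each run are written to Tiled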