Commit a9bd4c4

Merge pull request #868 from genematx/fix-dims
Keep Dimension Names
2 parents: dac83ab + f82c34d

4 files changed: +70 additions, -15 deletions

bluesky-tiled-plugins/bluesky_tiled_plugins/consolidators.py

Lines changed: 41 additions & 8 deletions

@@ -140,6 +140,9 @@ def __init__(self, stream_resource: StreamResource, descriptor: EventDescriptor)
         self._num_rows: int = 0  # Number of rows in the Data Source (all rows, includung skips)
         self._seqnums_to_indices_map: dict[int, int] = {}
 
+        # Set the dimension names if provided
+        self.dims: tuple[str, ...] = tuple(data_desc.get("dims", ()))
+
     @classmethod
     def get_supported_mimetype(cls, sres):
         if sres["mimetype"] not in cls.supported_mimetypes:
@@ -241,6 +244,7 @@ def structure(self) -> ArrayStructure:
             data_type=self.data_type,
             shape=self.shape,
             chunks=self.chunks,
+            dims=self.dims if self.dims else None,
         )
 
     def consume_stream_datum(self, doc: StreamDatum):
@@ -300,22 +304,24 @@ def update_from_stream_resource(self, stream_resource: StreamResource):
 
         raise NotImplementedError("This method is not implemented in the base Consolidator class.")
 
-    def validate(self, adapters_by_mimetype=None, fix_errors=False):
+    def validate(self, adapters_by_mimetype=None, fix_errors=False) -> list[str]:
         """Validate the Consolidator's state against the expected structure"""
 
         # User-provided adapters take precedence over defaults.
         all_adapters_by_mimetype = collections.ChainMap((adapters_by_mimetype or {}), DEFAULT_ADAPTERS_BY_MIMETYPE)
         adapter_class = all_adapters_by_mimetype[self.mimetype]
 
         # Initialize adapter from uris and determine the structure
-        uris = [asset.data_uri for asset in self.assets if asset.parameter == "data_uris"]
+        uris = [asset.data_uri for asset in self.assets]
         structure = adapter_class.from_uris(*uris, **self.adapter_parameters()).structure()
+        notes = []
 
         if self.shape != structure.shape:
             if not fix_errors:
                 raise ValueError(f"Shape mismatch: {self.shape} != {structure.shape}")
             else:
-                warnings.warn(f"Fixing shape mismatch: {self.shape} -> {structure.shape}", stacklevel=2)
+                msg = f"Fixed shape mismatch: {self.shape} -> {structure.shape}"
+                warnings.warn(msg, stacklevel=2)
                 if self.join_method == "stack":
                     self._num_rows = structure.shape[0]
                     self.datum_shape = structure.shape[1:]
@@ -324,26 +330,53 @@ def validate(self, adapters_by_mimetype=None, fix_errors=False):
                     multiplier = 1 if structure.shape[0] % structure.chunks[0][0] else structure.chunks[0][0]
                     self._num_rows = structure.shape[0] // multiplier
                     self.datum_shape = (multiplier,) + structure.shape[1:]
+                notes.append(msg)
 
         if self.chunks != structure.chunks:
             if not fix_errors:
                 raise ValueError(f"Chunk shape mismatch: {self.chunks} != {structure.chunks}")
             else:
                 _chunk_shape = tuple(c[0] for c in structure.chunks)
-                warnings.warn(f"Fixing chunk shape mismatch: {self.chunk_shape} -> {_chunk_shape}", stacklevel=2)
+                msg = f"Fixed chunk shape mismatch: {self.chunk_shape} -> {_chunk_shape}"
+                warnings.warn(msg, stacklevel=2)
                 self.chunk_shape = _chunk_shape
+                notes.append(msg)
 
         if self.data_type != structure.data_type:
             if not fix_errors:
                 raise ValueError(f"dtype mismatch: {self.data_type} != {structure.data_type}")
             else:
-                warnings.warn(
-                    f"Fixing dtype mismatch: {self.data_type.to_numpy_dtype()} -> {structure.data_type.to_numpy_dtype()}",  # noqa
-                    stacklevel=2,
+                msg = (
+                    f"Fixed dtype mismatch: {self.data_type.to_numpy_dtype()} "
+                    f"-> {structure.data_type.to_numpy_dtype()}"
                 )
+                warnings.warn(msg, stacklevel=2)
                 self.data_type = structure.data_type
+                notes.append(msg)
+
+        if self.dims and (len(self.dims) != len(structure.shape)):
+            if not fix_errors:
+                raise ValueError(
+                    f"Number of dimension names mismatch for a "
+                    f"{len(structure.shape)}-dimensional array: {self.dims}"
+                )
+            else:
+                old_dims = self.dims
+                if len(old_dims) < len(structure.shape):
+                    self.dims = (
+                        ("time",)
+                        + old_dims
+                        + tuple(f"dim{i}" for i in range(len(old_dims) + 1, len(structure.shape)))
+                    )
+                else:
+                    self.dims = old_dims[: len(structure.shape)]
+                msg = f"Fixed dimension names: {old_dims} -> {self.dims}"
+                warnings.warn(msg, stacklevel=2)
+                notes.append(msg)
+
+        assert self.get_adapter() is not None, "Adapter can not be initialized"
 
-        assert self.get_adapter() is not None, "Adapter can not not initialized"
+        return notes
 
 
 class CSVConsolidator(ConsolidatorBase):
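
For reference, the padding/truncation rule that validate() now applies to mismatched dimension names can be exercised in isolation. A minimal sketch, where the fix_dims helper is hypothetical and written only to illustrate the rule added in this diff:

# Hypothetical helper illustrating the dimension-name fix-up rule in validate().
def fix_dims(dims: tuple[str, ...], shape: tuple[int, ...]) -> tuple[str, ...]:
    if len(dims) < len(shape):
        # Too few names: assume a leading "time" axis and pad trailing axes as dim<i>.
        return ("time",) + dims + tuple(f"dim{i}" for i in range(len(dims) + 1, len(shape)))
    # Too many names: keep only the leading ones.
    return dims[: len(shape)]

assert fix_dims(("dim_x", "dim_y"), (5, 13, 17)) == ("time", "dim_x", "dim_y")
assert fix_dims(("time", "dim_x", "dim_y"), (5, 13)) == ("time", "dim_x")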

bluesky-tiled-plugins/bluesky_tiled_plugins/tiled_writer.py

Lines changed: 10 additions & 7 deletions

@@ -198,6 +198,7 @@ def __init__(
         self._emitted: set[str] = set()  # UIDs of the StreamResource documents that have been emitted
         self._int_keys: set[str] = set()  # Names of internal data_keys
         self._ext_keys: set[str] = set()
+        self.notes: list[str] = []  # Human-readable notes about any modifications made to the documents
 
     def _convert_resource_to_stream_resource(self, doc: Union[Resource, StreamResource]) -> StreamResource:
         """Make changes to and return a shallow copy of StreamRsource dictionary adhering to the new structure.
@@ -338,6 +339,8 @@ def stop(self, doc: RunStop):
                         f"Cannot emit StreamDatum for {data_key} because the corresponding Datum document is missing."
                     )
 
+        doc["_run_normalizer_notes"] = self.notes or []  # Add notes about modifications to the stop document
+
         self.emit(DocumentNames.stop, doc)
 
     def descriptor(self, doc: EventDescriptor):
@@ -351,14 +354,14 @@ def descriptor(self, doc: EventDescriptor):
                 if f"_{name}" in doc["data_keys"].keys():
                     raise ValueError(f"Cannot rename {name} to _{name} because it already exists")
                 doc["data_keys"][f"_{name}"] = doc["data_keys"].pop(name)
-                for obj_data_keys_list in doc["object_keys"].values():
+                for obj_data_keys_list in doc.get("object_keys", {}).values():
                     if name in obj_data_keys_list:
                         obj_data_keys_list.remove(name)
                         obj_data_keys_list.append(f"_{name}")
 
         # Rename some fields (in-place) to match the current schema for the descriptor
         # Loop over all dictionaries that specify data_keys (both event data_keys or configuration data_keys)
-        conf_data_keys = (obj["data_keys"].values() for obj in doc["configuration"].values())
+        conf_data_keys = (obj["data_keys"].values() for obj in doc.get("configuration", {}).values())
         for data_keys_spec in itertools.chain(doc["data_keys"].values(), *conf_data_keys):
             # Determine numpy data type. From highest precedent to lowest:
             # 1. Try 'dtype_descr', optional, if present -- this is a structural dtype
@@ -376,8 +379,9 @@ def descriptor(self, doc: EventDescriptor):
             ):
                 data_keys_spec["dtype_numpy"] = dtype_numpy
 
-        # Ensure that all event data_keys have object_name assigned (for consistency)
-        for obj_name, data_keys_list in doc["object_keys"].items():
+        # Ensure that all event data_keys have object_name assigned, if known (for consistency)
+        # If "object_keys" are not present, do not reconstruct them -- they are optional
+        for obj_name, data_keys_list in doc.get("object_keys", {}).items():
             for key in data_keys_list:
                 doc["data_keys"][key]["object_name"] = obj_name
 
@@ -466,6 +470,7 @@ def datum(self, doc: Datum):
         if patch := self.patches.get("datum"):
             doc = patch(doc)
 
+        # Keep the Datum document in memory until it is referenced by an Event document
        self._datum_cache[doc["datum_id"]] = doc
 
    def datum_page(self, doc: DatumPage):
@@ -613,9 +618,6 @@ def stop(self, doc: RunStop):
         self.root_node.update_metadata(metadata={"stop": doc, **dict(self.root_node.metadata)}, drop_revision=True)
 
     def descriptor(self, doc: EventDescriptor):
-        if self.root_node is None:
-            raise RuntimeError("RunWriter is not properly initialized: no Start document has been recorded.")
-
         desc_name = doc["name"]  # Name of the descriptor/stream
         self.data_keys.update(doc.get("data_keys", {}))
 
@@ -684,6 +686,7 @@ def get_sres_node(self, sres_uid: str, desc_uid: Optional[str] = None) -> tuple[
         if not desc_uid:
             raise RuntimeError("Descriptor uid must be specified to initialise a Stream Resource node")
 
+        # Define `full_data_key` as desc_name + _ + data_key to ensure uniqueness across streams
         sres_doc = self._stream_resource_cache[sres_uid]
         desc_node = self._desc_nodes[desc_uid]
         full_data_key = f"{desc_node.item['id']}_{sres_doc['data_key']}"  # desc_name + data_key
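
The fixes recorded by the normalizer are now surfaced to downstream consumers through the stop document. A minimal sketch of how a subscriber might read them; the callback below is illustrative, and only the "_run_normalizer_notes" key comes from this diff:

# Illustrative subscriber; only the "_run_normalizer_notes" key is from this PR.
def on_document(name: str, doc: dict) -> None:
    if name == "stop":
        for note in doc.get("_run_normalizer_notes", []):
            print(f"RunNormalizer modified the documents: {note}")

Recording the notes on the stop document keeps otherwise-silent fixes auditable after the run has completed.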

bluesky-tiled-plugins/tests/examples/external_assets.json

Lines changed: 5 additions & 0 deletions

@@ -43,6 +43,11 @@
             13,
             17
         ],
+        "dims": [
+            "time",
+            "dim_x",
+            "dim_y"
+        ],
         "external": "STREAM:",
         "object_name": "det-obj1"
     }
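
The new "dims" entry rides along in the data_key of the descriptor, where the consolidator picks it up via data_desc.get("dims", ()). A sketch of the resulting data_key as a Python dict; the leading shape value is hypothetical, since only the trailing 13 and 17 axes appear in the diff context:

# Sketch of a data_key carrying dimension names; the leading shape value is
# hypothetical -- only the trailing 13 and 17 appear in the diff context.
data_key = {
    "shape": [1, 13, 17],
    "dims": ["time", "dim_x", "dim_y"],
    "external": "STREAM:",
    "object_name": "det-obj1",
}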

bluesky-tiled-plugins/tests/test_tiled_writer.py

Lines changed: 14 additions & 0 deletions

@@ -382,6 +382,20 @@ def test_with_correct_sample_runs(client, batch_size, external_assets_folder, fn
         assert stream.read() is not None
 
 
+def test_dims_names(client, external_assets_folder):
+    tw = TiledWriter(client)
+
+    for item in render_templated_documents("external_assets.json", external_assets_folder):
+        if item["name"] == "start":
+            uid = item["doc"]["uid"]
+        tw(**item)
+
+    run = client[uid]
+
+    assert run["primary"]["det-key1"].structure().dims is None
+    assert run["primary"]["det-key2"].structure().dims == ("time", "dim_x", "dim_y")
+
+
 @pytest.mark.parametrize(
     "batch_size, expected_patch_shapes, expected_patch_offsets",
     [(1, (1, 1, 1), (0, 1, 2)), (2, (2, 1), (0, 2)), (5, (3,), (0,))],
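
Once a run is written, the dimension names are queryable from any Tiled client in the same way the test does it. A usage sketch, where the paths and key names follow the test above:

# Querying dimension names after the run is written (paths follow the test above).
run = client[uid]
arr = run["primary"]["det-key2"]
print(arr.structure().dims)  # ("time", "dim_x", "dim_y")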
