@@ -289,6 +289,30 @@ def if_instance_do(x: C, cls: type[C], func: Callable[[C], T]) -> T:
289289AxesPoints = dict [Axis , npt .NDArray [np .float64 ]]
290290
291291
292+ # @dataclass
293+ class ScanSlice (Generic [Axis ]):
294+ """Generalization of the Dimensions class.
295+
296+ Only holds information but no methods to handle it.
297+ """
298+
299+ def __init__ (
300+ self ,
301+ axes : list [Axis ],
302+ midpoints : AxesPoints [Axis ] | None = None ,
303+ lower : AxesPoints [Axis ] | None = None ,
304+ upper : AxesPoints [Axis ] | None = None ,
305+ gap : GapArray | None = None ,
306+ duration : DurationArray | None = None ,
307+ ):
308+ self .axes = axes
309+ self .midpoints = midpoints
310+ self .lower = lower
311+ self .upper = upper
312+ self .gap = gap
313+ self .duration = duration
314+
315+
292316class Dimension (Generic [Axis ]):
293317 """Represents a series of scan frames along a number of axes.
294318
@@ -319,22 +343,24 @@ class Dimension(Generic[Axis]):
319343
320344 def __init__ (
321345 self ,
322- midpoints : AxesPoints [Axis ],
346+ midpoints : AxesPoints [Axis ] | None = None ,
323347 lower : AxesPoints [Axis ] | None = None ,
324348 upper : AxesPoints [Axis ] | None = None ,
325349 gap : GapArray | None = None ,
350+ duration : DurationArray | None = None ,
326351 ):
327352 #: The midpoints of scan frames for each axis
328353 self .midpoints = midpoints
329354 #: The lower bounds of each scan frame in each axis for fly-scanning
330355 self .lower = lower or midpoints
331356 #: The upper bounds of each scan frame in each axis for fly-scanning
332357 self .upper = upper or midpoints
358+ self .duration = duration
333359 if gap is not None :
334360 #: Whether there is a gap between this frame and the previous. First
335361 #: element is whether there is a gap between the last frame and the first
336362 self .gap = gap
337- else :
363+ elif gap is None and self . upper is not None and self . lower is not None :
338364 # Need to calculate gap as not passed one
339365 # We have a gap if upper[i] != lower[i+1] for any axes
340366 axes_gap = [
@@ -344,31 +370,38 @@ def __init__(
344370 )
345371 ]
346372 self .gap = np .logical_or .reduce (axes_gap )
373+ else :
374+ self .gap = GapArray (0 )
347375 # Check all axes and ordering are the same
348- assert list (self .midpoints ) == list (self .lower ) == list (self .upper ), (
349- f"Mismatching axes "
350- f"{ list (self .midpoints )} != { list (self .lower )} != { list (self .upper )} "
351- )
352- # Check all lengths are the same
353- lengths = {
354- len (arr )
355- for d in (self .midpoints , self .lower , self .upper )
356- for arr in d .values ()
357- }
358- lengths .add (len (self .gap ))
359- assert len (lengths ) <= 1 , f"Mismatching lengths { list (lengths )} "
376+ if (
377+ self .midpoints is not None
378+ and self .lower is not None
379+ and self .upper is not None
380+ ):
381+ assert list (self .midpoints ) == list (self .lower ) == list (self .upper ), (
382+ f"Mismatching axes "
383+ f"{ list (self .midpoints )} != { list (self .lower )} != { list (self .upper )} "
384+ )
385+ # Check all lengths are the same
386+ lengths = {
387+ len (arr )
388+ for d in (self .midpoints , self .lower , self .upper )
389+ for arr in d .values ()
390+ }
391+ lengths .add (len (self .gap ))
392+ assert len (lengths ) <= 1 , f"Mismatching lengths { list (lengths )} "
360393
361394 def axes (self ) -> list [Axis ]:
362395 """The axes which will move during the scan.
363396
364397 These will be present in `midpoints`, `lower` and `upper`.
365398 """
366- return list (self .midpoints .keys ())
399+ return list (self .midpoints .keys ()) if self . midpoints is not None else []
367400
368401 def __len__ (self ) -> int :
369402 """The number of frames in this section of the scan."""
370403 # All axespoints arrays are same length, pick the first one
371- return len (self .gap )
404+ return len (self .gap ) if self . gap is not None else 0
372405
373406 def extract (
374407 self , indices : npt .NDArray [np .signedinteger [Any ]], calculate_gap : bool = True
@@ -480,6 +513,18 @@ def _merge_frames(
480513 upper = dict_merge ([fs .upper for fs in stack ])
481514 if any (fs .midpoints is not fs .upper for fs in stack )
482515 else None ,
516+ duration = stack [0 ].duration ,
517+ )
518+
519+
520+ def Dimension2Slice (dimension : Dimension [Axis ]):
521+ return ScanSlice (
522+ axes = dimension .axes (),
523+ midpoints = dimension .midpoints ,
524+ upper = dimension .upper ,
525+ lower = dimension .lower ,
526+ gap = dimension .gap ,
527+ duration = dimension .duration ,
483528 )
484529
485530
@@ -488,12 +533,15 @@ class SnakedDimension(Dimension[Axis]):
488533
489534 def __init__ (
490535 self ,
491- midpoints : AxesPoints [Axis ],
536+ midpoints : AxesPoints [Axis ] | None = None ,
492537 lower : AxesPoints [Axis ] | None = None ,
493538 upper : AxesPoints [Axis ] | None = None ,
494539 gap : GapArray | None = None ,
540+ duration : DurationArray | None = None ,
495541 ):
496- super ().__init__ (midpoints , lower = lower , upper = upper , gap = gap )
542+ super ().__init__ (
543+ midpoints , lower = lower , upper = upper , gap = gap , duration = duration
544+ )
497545 # Override first element of gap to be True, as subsequent runs
498546 # of snake scans are always joined end -> start
499547 self .gap [0 ] = False
@@ -645,8 +693,8 @@ def __init__(
645693 if num is not None and start + num < self .end_index :
646694 self .end_index = start + num
647695
648- def consume (self , num : int | None = None ) -> Dimension [Axis ]:
649- """Consume at most num frames from the Path and return as a Dimension object.
696+ def consume (self , num : int | None = None ) -> ScanSlice [Axis ]:
697+ """Consume at most num frames from the Path and return as a l object.
650698
651699 >>> fx = SnakedDimension({"x": np.array([1, 2])})
652700 >>> fy = Dimension({"y": np.array([3, 4])})
@@ -665,7 +713,11 @@ def consume(self, num: int | None = None) -> Dimension[Axis]:
665713 indices = np .arange (self .index , end_index )
666714 self .index = end_index
667715 stack : Dimension [Axis ] = Dimension (
668- {}, {}, {}, np .zeros (indices .shape , dtype = np .bool_ )
716+ {},
717+ {},
718+ {},
719+ np .zeros (indices .shape , dtype = np .bool_ ),
720+ np .zeros (indices .shape , dtype = np .float64 ),
669721 )
670722 # Example numbers below from a 2x3x4 ZxYxX scan
671723 for i , frames in enumerate (self .stack ):
@@ -691,7 +743,10 @@ def consume(self, num: int | None = None) -> Dimension[Axis]:
691743 sliced .gap &= in_gap
692744 # Zip it with the output Dimension object
693745 stack = stack .zip (sliced )
694- return stack
746+
747+ test_ = Dimension2Slice (stack )
748+
749+ return test_
695750
696751 def __len__ (self ) -> int :
697752 """Number of frames left in a scan, reduces when `consume` is called."""
@@ -739,4 +794,4 @@ def __iter__(self) -> Iterator[dict[Axis, float]]:
739794 path = Path (self .stack )
740795 while len (path ):
741796 frames = path .consume (1 )
742- yield {a : frames .midpoints [a ][0 ] for a in frames .axes () }
797+ yield {a : frames .midpoints [a ][0 ] for a in frames .axes }
0 commit comments