Skip to content

Commit c42b620

Browse files
committed
Initial version of the PR
Updated names: Frames -> Dimensions SnakedFrames -> SnakedDimensions Added new ScanSlice class to hold information about the scan Added duration field to Dimensions and SnakedDimensions class Updated consume method on the Path class to return a ScanSlice Updated Specs to handle new changes
1 parent 9caa78d commit c42b620

File tree

2 files changed

+103
-29
lines changed

2 files changed

+103
-29
lines changed

src/scanspec/core.py

Lines changed: 78 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,30 @@ def if_instance_do(x: C, cls: type[C], func: Callable[[C], T]) -> T:
289289
AxesPoints = dict[Axis, npt.NDArray[np.float64]]
290290

291291

292+
# @dataclass
293+
class ScanSlice(Generic[Axis]):
294+
"""Generalization of the Dimensions class.
295+
296+
Only holds information but no methods to handle it.
297+
"""
298+
299+
def __init__(
300+
self,
301+
axes: list[Axis],
302+
midpoints: AxesPoints[Axis] | None = None,
303+
lower: AxesPoints[Axis] | None = None,
304+
upper: AxesPoints[Axis] | None = None,
305+
gap: GapArray | None = None,
306+
duration: DurationArray | None = None,
307+
):
308+
self.axes = axes
309+
self.midpoints = midpoints
310+
self.lower = lower
311+
self.upper = upper
312+
self.gap = gap
313+
self.duration = duration
314+
315+
292316
class Dimension(Generic[Axis]):
293317
"""Represents a series of scan frames along a number of axes.
294318
@@ -319,22 +343,24 @@ class Dimension(Generic[Axis]):
319343

320344
def __init__(
321345
self,
322-
midpoints: AxesPoints[Axis],
346+
midpoints: AxesPoints[Axis] | None = None,
323347
lower: AxesPoints[Axis] | None = None,
324348
upper: AxesPoints[Axis] | None = None,
325349
gap: GapArray | None = None,
350+
duration: DurationArray | None = None,
326351
):
327352
#: The midpoints of scan frames for each axis
328353
self.midpoints = midpoints
329354
#: The lower bounds of each scan frame in each axis for fly-scanning
330355
self.lower = lower or midpoints
331356
#: The upper bounds of each scan frame in each axis for fly-scanning
332357
self.upper = upper or midpoints
358+
self.duration = duration
333359
if gap is not None:
334360
#: Whether there is a gap between this frame and the previous. First
335361
#: element is whether there is a gap between the last frame and the first
336362
self.gap = gap
337-
else:
363+
elif gap is None and self.upper is not None and self.lower is not None:
338364
# Need to calculate gap as not passed one
339365
# We have a gap if upper[i] != lower[i+1] for any axes
340366
axes_gap = [
@@ -344,31 +370,38 @@ def __init__(
344370
)
345371
]
346372
self.gap = np.logical_or.reduce(axes_gap)
373+
else:
374+
self.gap = GapArray(0)
347375
# Check all axes and ordering are the same
348-
assert list(self.midpoints) == list(self.lower) == list(self.upper), (
349-
f"Mismatching axes "
350-
f"{list(self.midpoints)} != {list(self.lower)} != {list(self.upper)}"
351-
)
352-
# Check all lengths are the same
353-
lengths = {
354-
len(arr)
355-
for d in (self.midpoints, self.lower, self.upper)
356-
for arr in d.values()
357-
}
358-
lengths.add(len(self.gap))
359-
assert len(lengths) <= 1, f"Mismatching lengths {list(lengths)}"
376+
if (
377+
self.midpoints is not None
378+
and self.lower is not None
379+
and self.upper is not None
380+
):
381+
assert list(self.midpoints) == list(self.lower) == list(self.upper), (
382+
f"Mismatching axes "
383+
f"{list(self.midpoints)} != {list(self.lower)} != {list(self.upper)}"
384+
)
385+
# Check all lengths are the same
386+
lengths = {
387+
len(arr)
388+
for d in (self.midpoints, self.lower, self.upper)
389+
for arr in d.values()
390+
}
391+
lengths.add(len(self.gap))
392+
assert len(lengths) <= 1, f"Mismatching lengths {list(lengths)}"
360393

361394
def axes(self) -> list[Axis]:
362395
"""The axes which will move during the scan.
363396
364397
These will be present in `midpoints`, `lower` and `upper`.
365398
"""
366-
return list(self.midpoints.keys())
399+
return list(self.midpoints.keys()) if self.midpoints is not None else []
367400

368401
def __len__(self) -> int:
369402
"""The number of frames in this section of the scan."""
370403
# All axespoints arrays are same length, pick the first one
371-
return len(self.gap)
404+
return len(self.gap) if self.gap is not None else 0
372405

373406
def extract(
374407
self, indices: npt.NDArray[np.signedinteger[Any]], calculate_gap: bool = True
@@ -480,6 +513,18 @@ def _merge_frames(
480513
upper=dict_merge([fs.upper for fs in stack])
481514
if any(fs.midpoints is not fs.upper for fs in stack)
482515
else None,
516+
duration=stack[0].duration,
517+
)
518+
519+
520+
def Dimension2Slice(dimension: Dimension[Axis]):
521+
return ScanSlice(
522+
axes=dimension.axes(),
523+
midpoints=dimension.midpoints,
524+
upper=dimension.upper,
525+
lower=dimension.lower,
526+
gap=dimension.gap,
527+
duration=dimension.duration,
483528
)
484529

485530

@@ -488,12 +533,15 @@ class SnakedDimension(Dimension[Axis]):
488533

489534
def __init__(
490535
self,
491-
midpoints: AxesPoints[Axis],
536+
midpoints: AxesPoints[Axis] | None = None,
492537
lower: AxesPoints[Axis] | None = None,
493538
upper: AxesPoints[Axis] | None = None,
494539
gap: GapArray | None = None,
540+
duration: DurationArray | None = None,
495541
):
496-
super().__init__(midpoints, lower=lower, upper=upper, gap=gap)
542+
super().__init__(
543+
midpoints, lower=lower, upper=upper, gap=gap, duration=duration
544+
)
497545
# Override first element of gap to be True, as subsequent runs
498546
# of snake scans are always joined end -> start
499547
self.gap[0] = False
@@ -645,8 +693,8 @@ def __init__(
645693
if num is not None and start + num < self.end_index:
646694
self.end_index = start + num
647695

648-
def consume(self, num: int | None = None) -> Dimension[Axis]:
649-
"""Consume at most num frames from the Path and return as a Dimension object.
696+
def consume(self, num: int | None = None) -> ScanSlice[Axis]:
697+
"""Consume at most num frames from the Path and return as a l object.
650698
651699
>>> fx = SnakedDimension({"x": np.array([1, 2])})
652700
>>> fy = Dimension({"y": np.array([3, 4])})
@@ -665,7 +713,11 @@ def consume(self, num: int | None = None) -> Dimension[Axis]:
665713
indices = np.arange(self.index, end_index)
666714
self.index = end_index
667715
stack: Dimension[Axis] = Dimension(
668-
{}, {}, {}, np.zeros(indices.shape, dtype=np.bool_)
716+
{},
717+
{},
718+
{},
719+
np.zeros(indices.shape, dtype=np.bool_),
720+
np.zeros(indices.shape, dtype=np.float64),
669721
)
670722
# Example numbers below from a 2x3x4 ZxYxX scan
671723
for i, frames in enumerate(self.stack):
@@ -691,7 +743,10 @@ def consume(self, num: int | None = None) -> Dimension[Axis]:
691743
sliced.gap &= in_gap
692744
# Zip it with the output Dimension object
693745
stack = stack.zip(sliced)
694-
return stack
746+
747+
test_ = Dimension2Slice(stack)
748+
749+
return test_
695750

696751
def __len__(self) -> int:
697752
"""Number of frames left in a scan, reduces when `consume` is called."""
@@ -739,4 +794,4 @@ def __iter__(self) -> Iterator[dict[Axis, float]]:
739794
path = Path(self.stack)
740795
while len(path):
741796
frames = path.consume(1)
742-
yield {a: frames.midpoints[a][0] for a in frames.axes()}
797+
yield {a: frames.midpoints[a][0] for a in frames.axes}

src/scanspec/specs.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
Midpoints,
2222
OtherAxis,
2323
Path,
24+
ScanSlice,
2425
SnakedDimension,
2526
StrictConfig,
2627
discriminated_union_of_subclasses,
@@ -80,7 +81,7 @@ def calculate(
8081
"""
8182
raise NotImplementedError(self)
8283

83-
def frames(self) -> Dimension[Axis]:
84+
def frames(self) -> ScanSlice[Axis]:
8485
"""Expand all the scan `Dimension` and return them."""
8586
return Path(self.calculate()).consume()
8687

@@ -476,6 +477,26 @@ def _dimensions_from_indexes(
476477
return [dimension]
477478

478479

480+
@dataclass(config=StrictConfig)
481+
class Duration(Spec[Axis]):
482+
"""Special dimension used to define the array of durations for each frame.
483+
484+
.. example_spec::
485+
486+
from scanspec.specs import Duration
487+
488+
spec = Duration(1,10)
489+
"""
490+
491+
duration: float = Field(description="Duration of each frame")
492+
num: int = Field(ge=1, description="Number of frames to produce")
493+
494+
def calculate(self, bounds=True, nested=False) -> list[Dimension[Axis]]:
495+
return [
496+
Dimension(None, None, None, None, duration=np.full(self.num, self.duration))
497+
]
498+
499+
479500
@dataclass(config=StrictConfig)
480501
class Line(Spec[Axis]):
481502
"""Linearly spaced frames with start and stop as first and last midpoints.
@@ -747,15 +768,13 @@ def get_constant_duration(frames: list[Dimension[Any]]) -> float | None:
747768
None: otherwise
748769
749770
"""
750-
duration_frame = [
751-
f for f in frames if DURATION in f.axes() and len(f.midpoints[DURATION])
752-
]
753-
if len(duration_frame) != 1 or len(duration_frame[0]) < 1:
771+
duration_frame = [f.duration for f in frames if f.duration is not None]
772+
if len(duration_frame) != 1 or duration_frame[0].size < 1:
754773
# Either no frame has DURATION axis,
755774
# the frame with a DURATION axis has 0 points,
756775
# or multiple frames have DURATION axis
757776
return None
758-
durations = duration_frame[0].midpoints[DURATION]
777+
durations = duration_frame[0]
759778
first_duration = durations[0]
760779
if np.any(durations != first_duration):
761780
# Not all durations are the same

0 commit comments

Comments
 (0)