Skip to content

Commit 0f7ad93

Browse files
Merge pull request #716 from zwicker-group/field_col_slice_project
Added `slice` and `project` methods to `FieldCollection`
2 parents 0b645e5 + 8391fd1 commit 0f7ad93

File tree

3 files changed

+112
-6
lines changed

3 files changed

+112
-6
lines changed

pde/fields/collection.py

Lines changed: 85 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,28 @@ def append(
589589
labels=_labels,
590590
)
591591

592+
def _apply_to_fields(
593+
self,
594+
func: Callable[[DataFieldBase], DataFieldBase],
595+
*,
596+
label: str | None = None,
597+
) -> FieldCollection:
598+
"""Apply function to every individual field.
599+
600+
Args:
601+
func (callable):
602+
Function applied to every field of this collections
603+
label (str, optional):
604+
Name of the returned collection. If omitted, the current label is used.
605+
606+
Returns:
607+
:class:`~pde.fields.collection.FieldCollection`: Modified fields
608+
"""
609+
if label is None:
610+
label = self.label
611+
fields = [func(fields) for fields in self.fields]
612+
return self.__class__(fields, label=label)
613+
592614
def _unary_operation(self: FieldCollection, op: Callable) -> FieldCollection:
593615
"""Perform an unary operation on this field collection.
594616
@@ -623,12 +645,11 @@ def interpolate_to_grid(
623645
Name of the returned field collection
624646
625647
Returns:
626-
:class:`~pde.fields.coolection.FieldCollection`: Interpolated data
648+
:class:`~pde.fields.collection.FieldCollection`: Interpolated data
627649
"""
628-
if label is None:
629-
label = self.label
630-
fields = [f.interpolate_to_grid(grid, fill=fill) for f in self.fields]
631-
return self.__class__(fields, label=label)
650+
return self._apply_to_fields(
651+
lambda f: f.interpolate_to_grid(grid, fill=fill), label=label
652+
)
632653

633654
def smooth(
634655
self,
@@ -651,7 +672,8 @@ def smooth(
651672
Name of the returned field
652673
653674
Returns:
654-
Field collection with smoothed data, stored at `out` if given.
675+
:class:`~pde.fields.collection.FieldCollection`:
676+
Smoothed data, stored at `out` if given.
655677
"""
656678
# allocate memory for storing output
657679
if out is None:
@@ -682,6 +704,63 @@ def magnitudes(self) -> np.ndarray:
682704
""":class:`~numpy.ndarray`: scalar magnitudes of all fields."""
683705
return np.array([field.magnitude for field in self]) # type: ignore
684706

707+
def project(
708+
self, axes: str | Sequence[str], *, label: str | None = None, **kwargs
709+
) -> FieldCollection:
710+
"""Project fields along given axes.
711+
712+
This is currently only implemented for scalar fields. If any field in the
713+
collection has higher rank, the entire process fails.
714+
715+
Args:
716+
axes (list of str):
717+
The names of the axes that are removed by the projection operation. The
718+
valid names for a given grid are the ones in the :attr:`GridBase.axes`
719+
attribute.
720+
label (str, optional):
721+
Name of the returned collection. If omitted, the current label is used.
722+
method (str):
723+
The projection method. This can be either 'integral' to integrate over
724+
the removed axes or 'average' to perform an average instead.
725+
726+
Returns:
727+
:class:`~pde.fields.collection.FieldCollection`:
728+
The projected data of all fields on a subgrid of the original grid.
729+
"""
730+
if not all(isinstance(f, ScalarField) for f in self):
731+
raise TypeError("All fields must be scalar fields to project data")
732+
return self._apply_to_fields(lambda f: f.project(axes, **kwargs), label=label) # type: ignore
733+
734+
def slice(
735+
self, position: dict[str, float], *, label: str | None = None, **kwargs
736+
) -> FieldCollection:
737+
"""Slice all fields at a given position.
738+
739+
This is currently only implemented for scalar fields. If any field in the
740+
collection has higher rank, the entire process fails.
741+
742+
Args:
743+
position (dict):
744+
Determines the location of the slice using a dictionary supplying
745+
coordinate values for a subset of axes. Axes not mentioned in the
746+
dictionary are retained and form the slice. For instance, in a 2d
747+
Cartesian grid, `position = {'x': 1}` slices along the y-direction at
748+
x=1. Additionally, the special positions 'low', 'mid', and 'high' are
749+
supported to reference relative positions along the axis.
750+
label (str, optional):
751+
Name of the returned collection. If omitted, the current label is used.
752+
method (str):
753+
The method used for slicing. Currently, we only support `nearest`, which
754+
takes data from cells defined on the grid.
755+
756+
Returns:
757+
:class:`~pde.fields.collection.FieldCollection`:
758+
The projected data of all fields on a subgrid of the original grid.
759+
"""
760+
if not all(isinstance(f, ScalarField) for f in self):
761+
raise TypeError("All fields must be scalar fields to slice data")
762+
return self._apply_to_fields(lambda f: f.slice(position, **kwargs), label=label) # type: ignore
763+
685764
def get_line_data( # type: ignore
686765
self,
687766
index: int = 0,

pde/fields/scalar.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,7 @@ def integral(self) -> Number:
275275
def project(
276276
self,
277277
axes: str | Sequence[str],
278+
*,
278279
method: Literal["integral", "average", "mean"] = "integral",
279280
label: str | None = None,
280281
) -> ScalarField:

tests/fields/test_field_collections.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,3 +371,29 @@ def test_merged_image_plotting(num):
371371
vmin=-1,
372372
vmax=[1, 2, 3, 4],
373373
)
374+
375+
376+
def test_collection_slice_project():
377+
"""Test slicing and projection of field data."""
378+
grid = UnitGrid([8, 16])
379+
fc = FieldCollection.scalar_random_uniform(2, grid)
380+
381+
f2 = fc.slice({"x": 0.5}, label="sliced")
382+
np.testing.assert_allclose(f2[0].data, fc[0].data[0, :])
383+
assert f2.label == "sliced"
384+
385+
f3 = fc.project("x", label="projected")
386+
np.testing.assert_allclose(f3[0].data, fc[0].data.sum(axis=0))
387+
assert f3.label == "projected"
388+
389+
390+
def test_collection_slice_project_wrong_type():
391+
"""Test slicing and projection of field data."""
392+
grid = UnitGrid([8, 16])
393+
fc = FieldCollection([VectorField(grid)])
394+
395+
with pytest.raises(TypeError):
396+
fc.slice({"x": 0.5})
397+
398+
with pytest.raises(TypeError):
399+
fc.project("x")

0 commit comments

Comments
 (0)