Skip to content

Commit 76533af

Browse files
authored
Merge pull request #226 from boutproject/fix-to_field_aligned-wrong-dim-order
Fix `to_field_aligned()`/`from_field_aligned()` for transposed arrays
2 parents ffda4b2 + 91833c6 commit 76533af

File tree

2 files changed

+58
-6
lines changed

2 files changed

+58
-6
lines changed

xbout/boutdataarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def _shift_z(self, zShift):
118118

119119
data_shifted_fft = data_fft * np.exp(phase.data)
120120

121-
data_shifted = fft.irfft(data_shifted_fft, n=nz)
121+
data_shifted = fft.irfft(data_shifted_fft, n=nz, axis=axis)
122122

123123
# Return a DataArray with the same attributes as self, but values from
124124
# data_shifted

xbout/tests/test_boutdataarray.py

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,10 @@ def test_remove_yboundaries(
9898
pytest.param(9, marks=pytest.mark.long),
9999
],
100100
)
101-
def test_to_field_aligned(self, bout_xyt_example_files, nz):
101+
@pytest.mark.parametrize(
102+
"permute_dims", [False, pytest.param(True, marks=pytest.mark.long)]
103+
)
104+
def test_to_field_aligned(self, bout_xyt_example_files, nz, permute_dims):
102105
dataset_list = bout_xyt_example_files(
103106
None, lengths=(3, 3, 4, nz), nxpe=1, nype=1, nt=1
104107
)
@@ -126,8 +129,15 @@ def test_to_field_aligned(self, bout_xyt_example_files, nz):
126129
n[t, x, y, z] = 1000.0 * t + 100.0 * x + 10.0 * y + z
127130

128131
n.attrs["direction_y"] = "Standard"
132+
133+
if permute_dims:
134+
n = n.transpose("t", "zeta", "x", "theta").compute()
135+
129136
n_al = n.bout.to_field_aligned()
130137

138+
if permute_dims:
139+
n_al = n_al.transpose("t", "x", "theta", "zeta").compute()
140+
131141
assert n_al.direction_y == "Aligned"
132142

133143
for t in range(ds.sizes["t"]):
@@ -195,7 +205,10 @@ def test_to_field_aligned(self, bout_xyt_example_files, nz):
195205
atol=0.0,
196206
) # noqa: E501
197207

198-
def test_to_field_aligned_dask(self, bout_xyt_example_files):
208+
@pytest.mark.parametrize(
209+
"permute_dims", [False, pytest.param(True, marks=pytest.mark.long)]
210+
)
211+
def test_to_field_aligned_dask(self, bout_xyt_example_files, permute_dims):
199212

200213
nz = 6
201214

@@ -231,8 +244,15 @@ def test_to_field_aligned_dask(self, bout_xyt_example_files):
231244
assert isinstance(n.data, dask.array.Array)
232245

233246
n.attrs["direction_y"] = "Standard"
247+
248+
if permute_dims:
249+
n = n.transpose("t", "zeta", "x", "theta").compute()
250+
234251
n_al = n.bout.to_field_aligned()
235252

253+
if permute_dims:
254+
n_al = n_al.transpose("t", "x", "theta", "zeta").compute()
255+
236256
assert n_al.direction_y == "Aligned"
237257

238258
for t in range(ds.sizes["t"]):
@@ -309,7 +329,10 @@ def test_to_field_aligned_dask(self, bout_xyt_example_files):
309329
pytest.param(9, marks=pytest.mark.long),
310330
],
311331
)
312-
def test_from_field_aligned(self, bout_xyt_example_files, nz):
332+
@pytest.mark.parametrize(
333+
"permute_dims", [False, pytest.param(True, marks=pytest.mark.long)]
334+
)
335+
def test_from_field_aligned(self, bout_xyt_example_files, nz, permute_dims):
313336
dataset_list = bout_xyt_example_files(
314337
None, lengths=(3, 3, 4, nz), nxpe=1, nype=1, nt=1
315338
)
@@ -337,8 +360,15 @@ def test_from_field_aligned(self, bout_xyt_example_files, nz):
337360
n[t, x, y, z] = 1000.0 * t + 100.0 * x + 10.0 * y + z
338361

339362
n.attrs["direction_y"] = "Aligned"
363+
364+
if permute_dims:
365+
n = n.transpose("t", "zeta", "x", "theta").compute()
366+
340367
n_nal = n.bout.from_field_aligned()
341368

369+
if permute_dims:
370+
n_nal = n_nal.transpose("t", "x", "theta", "zeta").compute()
371+
342372
assert n_nal.direction_y == "Standard"
343373

344374
for t in range(ds.sizes["t"]):
@@ -407,7 +437,12 @@ def test_from_field_aligned(self, bout_xyt_example_files, nz):
407437
) # noqa: E501
408438

409439
@pytest.mark.parametrize("stag_location", ["CELL_XLOW", "CELL_YLOW", "CELL_ZLOW"])
410-
def test_to_field_aligned_staggered(self, bout_xyt_example_files, stag_location):
440+
@pytest.mark.parametrize(
441+
"permute_dims", [False, pytest.param(True, marks=pytest.mark.long)]
442+
)
443+
def test_to_field_aligned_staggered(
444+
self, bout_xyt_example_files, stag_location, permute_dims
445+
):
411446
dataset_list = bout_xyt_example_files(
412447
None, lengths=(3, 3, 4, 8), nxpe=1, nype=1, nt=1
413448
)
@@ -434,8 +469,14 @@ def test_to_field_aligned_staggered(self, bout_xyt_example_files, stag_location)
434469
for z in range(ds.sizes["zeta"]):
435470
n[t, x, y, z] = 1000.0 * t + 100.0 * x + 10.0 * y + z
436471

472+
if permute_dims:
473+
n = n.transpose("t", "zeta", "x", "theta").compute()
474+
437475
n_al = n.bout.to_field_aligned().copy(deep=True)
438476

477+
if permute_dims:
478+
n_al = n_al.transpose("t", "x", "theta", "zeta").compute()
479+
439480
assert n_al.direction_y == "Aligned"
440481

441482
# make 'n' staggered
@@ -459,7 +500,12 @@ def test_to_field_aligned_staggered(self, bout_xyt_example_files, stag_location)
459500
npt.assert_equal(n_stag_al.values, n_al.values)
460501

461502
@pytest.mark.parametrize("stag_location", ["CELL_XLOW", "CELL_YLOW", "CELL_ZLOW"])
462-
def test_from_field_aligned_staggered(self, bout_xyt_example_files, stag_location):
503+
@pytest.mark.parametrize(
504+
"permute_dims", [False, pytest.param(True, marks=pytest.mark.long)]
505+
)
506+
def test_from_field_aligned_staggered(
507+
self, bout_xyt_example_files, stag_location, permute_dims
508+
):
463509
dataset_list = bout_xyt_example_files(
464510
None, lengths=(3, 3, 4, 8), nxpe=1, nype=1, nt=1
465511
)
@@ -488,8 +534,14 @@ def test_from_field_aligned_staggered(self, bout_xyt_example_files, stag_locatio
488534
n.attrs["direction_y"] = "Aligned"
489535
ds["T"].attrs["direction_y"] = "Aligned"
490536

537+
if permute_dims:
538+
n = n.transpose("t", "zeta", "x", "theta").compute()
539+
491540
n_nal = n.bout.from_field_aligned().copy(deep=True)
492541

542+
if permute_dims:
543+
n_nal = n_nal.transpose("t", "x", "theta", "zeta").compute()
544+
493545
assert n_nal.direction_y == "Standard"
494546

495547
# make 'n' staggered

0 commit comments

Comments
 (0)