Skip to content

Commit 7505c6d

Browse files
authored
Merge pull request #195 from boutproject/staggered-integrate
Support staggered variables in BoutDataset.integrate_midpoints()
2 parents 3f8e522 + c8dc335 commit 7505c6d

File tree

4 files changed

+93
-58
lines changed

4 files changed

+93
-58
lines changed

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ setup_requires =
2828
setuptools_scm_git_archive
2929
install_requires =
3030
xarray>=0.17.0
31-
boutdata>=0.1.2
31+
boutdata>=0.1.4
3232
dask[array]>=2.10.0
3333
natsort>=5.5.0
3434
matplotlib>=3.1.1,!=3.3.0,!=3.3.1,!=3.3.2

xbout/boutdataset.py

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,12 @@ def integrate_midpoints(self, variable, *, dims=None):
305305
"""
306306
ds = self.data
307307

308+
if isinstance(variable, str):
309+
variable = ds[variable]
310+
311+
location = variable.cell_location
312+
suffix = "" if location == "CELL_CENTRE" else f"_{location}"
313+
308314
tcoord = ds.metadata["bout_tdim"]
309315
xcoord = ds.metadata["bout_xdim"]
310316
ycoord = ds.metadata["bout_ydim"]
@@ -331,55 +337,61 @@ def integrate_midpoints(self, variable, *, dims=None):
331337
elif isinstance(dims, str):
332338
dims = [dims]
333339

334-
dx = ds["dx"]
335-
dy = ds["dy"]
340+
dx = ds[f"dx{suffix}"]
341+
dy = ds[f"dy{suffix}"]
336342
dz = ds.metadata["dz"]
337343

338344
# Work out the spatial volume element
339345
if xcoord in dims and ycoord in dims and zcoord in dims:
340346
# Volume integral, use the 3d Jacobian "J"
341-
spatial_volume_element = ds["J"] * dx * dy * dz
347+
spatial_volume_element = ds[f"J{suffix}"] * dx * dy * dz
342348
elif xcoord in dims and ycoord in dims:
343349
# 2d integral on poloidal planes
344-
if ds[variable].direction_y == "Standard":
350+
if variable.direction_y == "Standard":
345351
# Need to use a metric constructed from basis vectors within the
346352
# poloidal plane, so use 'reciprocal basis vectors' Grad(x^i)
347353
# J = 1/sqrt(det(g_2d))
348354
# det(g_2d) = g11*g22 - g12**2
349-
g = ds["g11"] * ds["g22"] - ds["g12"] ** 2
355+
g = ds[f"g11{suffix}"] * ds[f"g22{suffix}"] - ds[f"g12{suffix}"] ** 2
350356
J = 1.0 / np.sqrt(g)
351-
elif ds[variable].direction_y == "Aligned":
357+
elif variable.direction_y == "Aligned":
352358
# Need to work out area element from metric coefficients. See book by
353359
# D'haeseleer, Hitchon, Callen and Shohet eq. (2.5.51).
354360
# Need to use a metric constructed from basis vectors within the
355361
# field-aligned x-y plane, so use 'tangent basis vectors' e_i
356362
# J = sqrt(g_11*g_22 - g_12**2)
357-
J = np.sqrt(ds["g_11"] * ds["g_22"] - ds["g_12"] ** 2)
363+
J = np.sqrt(
364+
ds[f"g_11{suffix}"] * ds[f"g_22{suffix}"] - ds[f"g_12{suffix}"] ** 2
365+
)
358366
spatial_volume_element = J * dx * dy
359367
elif xcoord in dims and zcoord in dims:
360368
# 2d integral on toroidal planes
361369
# Need to work out area element from metric coefficients. See book by
362370
# D'haeseleer, Hitchon, Callen and Shohet eq. (2.5.51)
363371
# J = sqrt(g_11*g_33 - g_13**2)
364-
J = np.sqrt(ds["g_11"] * ds["g_33"] - ds["g_13"] ** 2)
372+
J = np.sqrt(
373+
ds[f"g_11{suffix}"] * ds[f"g_33{suffix}"] - ds[f"g_13{suffix}"] ** 2
374+
)
365375
spatial_volume_element = J * dx * dz
366376
elif ycoord in dims and zcoord in dims:
367377
# 2d integral on flux-surfaces
368378
# Need to work out area element from metric coefficients. See book by
369379
# D'haeseleer, Hitchon, Callen and Shohet eq. (2.5.51)
370380
# J = sqrt(g_22*g_33 - g_23**2)
371-
J = np.sqrt(ds["g_22"] * ds["g_33"] - ds["g_23"] ** 2)
381+
J = np.sqrt(
382+
ds[f"g_22{suffix}"] * ds[f"g_33{suffix}"] - ds[f"g_23{suffix}"] ** 2
383+
)
372384
spatial_volume_element = J * dy * dz
373385
elif xcoord in dims:
374-
if ds[variable].direction_y == "Aligned":
386+
if variable.direction_y == "Aligned":
375387
raise ValueError(
376388
"Variable is field-aligned, but radial integral along coordinate "
377389
"line in globally field-aligned coordinates not supported"
378390
)
379391
# 1d radial integral, line element is sqrt(g_11)*dx
380-
spatial_volume_element = np.sqrt(ds["g_11"]) * dx
392+
spatial_volume_element = np.sqrt(ds[f"g_11{suffix}"]) * dx
381393
elif ycoord in dims:
382-
if ds[variable].direction_y == "Standard":
394+
if variable.direction_y == "Standard":
383395
# Poloidal integral, line element is e_y projected onto a unit vector in
384396
# the poloidal direction. e_z is in the toroidal direction and Grad(x)
385397
# is orthogonal to flux surfaces, so their cross product is in the
@@ -398,34 +410,35 @@ def integrate_midpoints(self, variable, *, dims=None):
398410
# For 'orthogonal' coordinates (radial and poloidal directions are
399411
# orthogonal) this is equal to 1/sqrt(g22)
400412
spatial_volume_element = (
401-
(ds["g_22"] * ds["g_33"] - ds["g_23"] ** 2)
402-
/ (ds["J"] * np.sqrt(ds["g11"] * ds["g_33"]))
413+
(
414+
ds[f"g_22{suffix}"] * ds[f"g_33{suffix}"]
415+
- ds[f"g_23{suffix}"] ** 2
416+
)
417+
/ (
418+
ds[f"J{suffix}"]
419+
* np.sqrt(ds[f"g11{suffix}"] * ds[f"g_33{suffix}"])
420+
)
403421
* dy
404422
)
405-
elif ds[variable].direction_y == "Aligned":
423+
elif variable.direction_y == "Aligned":
406424
# Parallel integral, line element is sqrt(g_22)*dy
407-
spatial_volume_element = np.sqrt(ds["g_22"]) * dy
425+
spatial_volume_element = np.sqrt(ds[f"g_22{suffix}"]) * dy
408426
elif zcoord in dims:
409427
# Toroidal integral, line element is sqrt(g_33)*dz
410-
spatial_volume_element = np.sqrt(ds["g_33"]) * dz
428+
spatial_volume_element = np.sqrt(ds[f"g_33{suffix}"]) * dz
411429
else:
412430
# No spatial integral
413431
spatial_volume_element = 1.0
414432

415-
if isinstance(variable, xr.DataArray):
416-
integrand = variable
417-
else:
418-
integrand = ds[variable]
419-
420433
spatial_dims = set(dims) - set([tcoord])
421434

422435
# Need to check if the variable being integrated is a Field2D, which does not
423436
# have a z-dimension to sum over. Other variables are OK because metric
424437
# coefficients, dx and dy all have both x- and y-dimensions so variable would be
425438
# broadcast to include them if necessary
426-
missing_z_sum = zcoord in dims and zcoord not in integrand.dims
439+
missing_z_sum = zcoord in dims and zcoord not in variable.dims
427440

428-
integrand = integrand * spatial_volume_element
441+
integrand = variable * spatial_volume_element
429442

430443
integral = integrand.sum(dim=spatial_dims)
431444

xbout/tests/test_boutdataset.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1304,11 +1304,17 @@ def test_integrate_midpoints_slab(self, bout_xyt_example_files):
13041304
rtol=1.4e-4,
13051305
)
13061306

1307-
def test_integrate_midpoints_salpha(self, bout_xyt_example_files):
1307+
@pytest.mark.parametrize(
1308+
"location", ["CELL_CENTRE", "CELL_XLOW", "CELL_YLOW", "CELL_ZLOW"]
1309+
)
1310+
def test_integrate_midpoints_salpha(self, bout_xyt_example_files, location):
13081311
# Create data
1312+
nx = 100
1313+
ny = 110
1314+
nz = 120
13091315
dataset_list = bout_xyt_example_files(
13101316
None,
1311-
lengths=(4, 100, 110, 120),
1317+
lengths=(4, nx, ny, nz),
13121318
nxpe=1,
13131319
nype=1,
13141320
nt=1,
@@ -1323,6 +1329,16 @@ def test_integrate_midpoints_salpha(self, bout_xyt_example_files):
13231329

13241330
# Integrate 1 so we just get volume, areas and lengths
13251331
ds["n"].values[:] = 1.0
1332+
ds["n"].attrs["cell_location"] = location
1333+
1334+
# remove boundary cells (don't want to integrate over those)
1335+
ds = ds.bout.remove_yboundaries()
1336+
if ds.metadata["keep_xboundaries"] and ds.metadata["MXG"] > 0:
1337+
mxg = ds.metadata["MXG"]
1338+
xslice = slice(mxg, -mxg)
1339+
else:
1340+
xslice = slice(None)
1341+
ds = ds.isel(x=xslice)
13261342

13271343
# Test geometry has major radius R and goes between minor radii a-Lr/2 and
13281344
# a+Lr/2
@@ -1331,18 +1347,14 @@ def test_integrate_midpoints_salpha(self, bout_xyt_example_files):
13311347
Lr = options.evaluate_scalar("mesh:Lr")
13321348
rinner = a - Lr / 2.0
13331349
router = a + Lr / 2.0
1350+
r = options.evaluate("mesh:r").squeeze()[xslice]
1351+
if location == "CELL_XLOW":
1352+
rinner = rinner - Lr / (2.0 * nx)
1353+
router = router - Lr / (2.0 * nx)
1354+
r = r - Lr / (2.0 * nx)
13341355
q = options.evaluate_scalar("mesh:q")
13351356
T_total = (ds.sizes["t"] - 1) * (ds["t"][1] - ds["t"][0]).values
13361357

1337-
# remove boundary cells (don't want to integrate over those)
1338-
ds = ds.bout.remove_yboundaries()
1339-
if ds.metadata["keep_xboundaries"] and ds.metadata["MXG"] > 0:
1340-
mxg = ds.metadata["MXG"]
1341-
xslice = slice(mxg, -mxg)
1342-
else:
1343-
xslice = slice(None)
1344-
ds = ds.isel(x=xslice)
1345-
13461358
# Volume of torus with circular cross-section of major radius R and minor radius
13471359
# a is 2*pi*R*pi*a^2
13481360
# https://en.wikipedia.org/wiki/Torus
@@ -1378,12 +1390,6 @@ def test_integrate_midpoints_salpha(self, bout_xyt_example_files):
13781390
# Area of torus with circular cross-section of major radius R and minor radius a
13791391
# is 2*pi*R*2*pi*a
13801392
# https://en.wikipedia.org/wiki/Torus
1381-
mxg = options._keys.get("MXG", 2)
1382-
if mxg == 0:
1383-
xslice = slice(None)
1384-
else:
1385-
xslice = slice(mxg, -mxg)
1386-
r = options.evaluate("mesh:r").squeeze()[xslice]
13871393
npt.assert_allclose(
13881394
ds.bout.integrate_midpoints("n", dims=["theta", "zeta"]),
13891395
(2.0 * np.pi * R * 2.0 * np.pi * r)[np.newaxis, :]
@@ -1418,6 +1424,8 @@ def test_integrate_midpoints_salpha(self, bout_xyt_example_files):
14181424
# x-z planes are 'conical frustrums', with area pi*(Rinner + Router)*Lr
14191425
# https://en.wikipedia.org/wiki/Frustum
14201426
theta = _1d_coord_from_spacing(ds["dy"], "theta").values
1427+
if location == "CELL_YLOW":
1428+
theta = theta - 2.0 * np.pi / (2.0 * ny)
14211429
Rinner = R + rinner * np.cos(theta)
14221430
Router = R + router * np.cos(theta)
14231431
npt.assert_allclose(
@@ -1549,17 +1557,21 @@ def func(theta):
15491557
)
15501558

15511559
# Toroidal lines have length 2*pi*Rxy
1560+
if location == "CELL_CENTRE":
1561+
R_2d = ds["R"]
1562+
else:
1563+
R_2d = ds[f"Rxy_{location}"]
15521564
npt.assert_allclose(
15531565
ds.bout.integrate_midpoints("n", dims=["zeta"]),
1554-
(2.0 * np.pi * ds["R"]).values[np.newaxis, :, :]
1566+
(2.0 * np.pi * R_2d).values[np.newaxis, :, :]
15551567
* np.ones(ds.sizes["t"])[:, np.newaxis, np.newaxis],
15561568
rtol=1.0e-5,
15571569
atol=0.0,
15581570
)
15591571
# Integrate in time too
15601572
npt.assert_allclose(
15611573
ds.bout.integrate_midpoints("n", dims=["t", "zeta"]),
1562-
T_total * (2.0 * np.pi * ds["R"]),
1574+
T_total * (2.0 * np.pi * R_2d),
15631575
rtol=1.0e-5,
15641576
atol=0.0,
15651577
)

xbout/tests/utils_for_tests.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,19 @@ def set_geometry_from_input_file(ds, name):
6868
"dx",
6969
"dy",
7070
]:
71-
# Need all arrays returned from options.evaluate() to be the right shape.
72-
# Recommend adding '0*x' or '0*y' in the input file expressions if the
73-
# expression would be 1d otherwise.
74-
ds[v] = ds[v].copy(
75-
data=np.broadcast_to(
76-
options.evaluate(f"mesh:{v}").squeeze(axis=2)[slices], shape_2d
71+
for location in ["CELL_CENTRE", "CELL_XLOW", "CELL_YLOW", "CELL_ZLOW"]:
72+
suffix = "" if location == "CELL_CENTRE" else f"_{location}"
73+
# Need all arrays returned from options.evaluate() to be the right shape.
74+
# Recommend adding '0*x' or '0*y' in the input file expressions if the
75+
# expression would be 1d otherwise.
76+
ds[v + suffix] = ds[v].copy(
77+
data=np.broadcast_to(
78+
options.evaluate(f"mesh:{v}", location=location).squeeze(axis=2)[
79+
slices
80+
],
81+
shape_2d,
82+
)
7783
)
78-
)
7984

8085
# Set dz as it would be calculated by BOUT++ (but don't support zmin, zmax or
8186
# zperiod here)
@@ -86,14 +91,19 @@ def set_geometry_from_input_file(ds, name):
8691

8792
# Add extra fields needed by "toroidal" geometry
8893
for v in ["Rxy", "Zxy", "psixy"]:
89-
# Need all arrays returned from options.evaluate() to be the right shape.
90-
# Recommend adding '0*x' or '0*y' in the input file expressions if the
91-
# expression would be 1d otherwise.
92-
ds[v] = ds["g11"].copy(
93-
data=np.broadcast_to(
94-
options.evaluate(f"mesh:{v}").squeeze(axis=2)[slices], shape_2d
94+
for location in ["CELL_CENTRE", "CELL_XLOW", "CELL_YLOW", "CELL_ZLOW"]:
95+
suffix = "" if location == "CELL_CENTRE" else f"_{location}"
96+
# Need all arrays returned from options.evaluate() to be the right shape.
97+
# Recommend adding '0*x' or '0*y' in the input file expressions if the
98+
# expression would be 1d otherwise.
99+
ds[v + suffix] = ds["g11"].copy(
100+
data=np.broadcast_to(
101+
options.evaluate(f"mesh:{v}", location=location).squeeze(axis=2)[
102+
slices
103+
],
104+
shape_2d,
105+
)
95106
)
96-
)
97107

98108
# Set fields that don't have to be in input files to NaN
99109
for v in ["G1", "G2", "G3"]:

0 commit comments

Comments
 (0)