Skip to content

Commit cb73a6f

Browse files
Handle special case for identical boundary conditions in get_boundary_axis function and add corresponding test (#644)
* Handle special case for identical boundary conditions in get_boundary_axis function and add corresponding test * Refactor get_boundary_axis to handle special case for identical boundary conditions more explicitly
1 parent 258a6fd commit cb73a6f

File tree

2 files changed

+23
-7
lines changed

2 files changed

+23
-7
lines changed

pde/grids/boundaries/axis.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from __future__ import annotations
2020

21+
import collections
2122
from typing import TYPE_CHECKING, Union
2223

2324
import numpy as np
@@ -408,23 +409,34 @@ def get_boundary_axis(
408409
:class:`~pde.grids.boundaries.axis.BoundaryAxisBase`:
409410
Appropriate boundary condition for the axis
410411
"""
412+
# Handle special case where two identical conditions are given. In particular, this
413+
# covers the special case where `data == ("periodic", "periodic")` and similar
414+
# constructs. These are converted to `data == "periodic"`, so the next check can
415+
# catch them properly.
416+
if isinstance(data, collections.abc.Sequence) and data[0] == data[1]:
417+
data = data[0]
418+
411419
# handle special case describing potentially periodic boundary conditions
412420
if isinstance(data, str) and data.startswith("auto_periodic_"):
413421
data = "periodic" if grid.periodic[axis] else data[len("auto_periodic_") :]
414422

415423
# handle different types of data that specify boundary conditions
416424
if isinstance(data, BoundaryAxisBase):
417-
# boundary is already an the correct format
425+
# boundary is already a fully fledged instance
418426
bcs = data
419-
elif data == "periodic" or data == ("periodic", "periodic"):
427+
428+
elif data == "periodic" or (
429+
isinstance(data, dict) and data.get("type") == "periodic"
430+
):
420431
# initialize a periodic boundary condition
421432
bcs = BoundaryPeriodic(grid, axis)
422-
elif data == "anti-periodic" or data == ("anti-periodic", "anti-periodic"):
423-
# initialize a anti-periodic boundary condition
433+
434+
elif data == "anti-periodic" or (
435+
isinstance(data, dict) and data.get("type") == "anti-periodic"
436+
):
437+
# initialize an anti-periodic boundary condition
424438
bcs = BoundaryPeriodic(grid, axis, flip_sign=True)
425-
elif isinstance(data, dict) and data.get("type") == "periodic":
426-
# initialize a periodic boundary condition
427-
bcs = BoundaryPeriodic(grid, axis)
439+
428440
else:
429441
# initialize independent boundary conditions for the two sides
430442
bcs = BoundaryPair.from_data(grid, axis, data, rank=rank)

tests/grids/boundaries/test_axis_boundaries.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,7 @@ def test_get_axis_boundaries():
7171
else:
7272
assert not b.periodic
7373
assert len(list(b)) == 2
74+
75+
# check double setting
76+
c = get_boundary_axis(g, 0, (data, data))
77+
assert b == c

0 commit comments

Comments
 (0)