Skip to content

Commit 0f7514f

Browse files
committed
Minor fixes for better scalar behavior
hg hash: 016f609813fe95864d10f2a0ad38f6fca5b23c4f
1 parent 54e1964 commit 0f7514f

File tree

3 files changed

+13
-3
lines changed

3 files changed

+13
-3
lines changed

dedalus/core/field.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -550,14 +550,17 @@ def as_ncc_operator(self, frozen_arg_basis_meta, cutoff, max_terms, cacheid=None
550550
if basis.separable:
551551
if not self.meta[basis.name]['constant']:
552552
raise ValueError("{} is non-constant along separable direction '{}'.".format(self, basis.name))
553+
# Scatter transverse-constant coefficients
553554
basis = domain.bases[-1]
554555
coeffs = np.zeros(basis.coeff_size, dtype=basis.coeff_dtype)
555-
# Scatter transverse-constant coefficients
556556
self.require_coeff_space()
557557
if domain.dist.rank == 0:
558558
select = (0,) * (domain.dim - 1)
559559
np.copyto(coeffs, self.data[select])
560560
domain.dist.comm_cart.Bcast(coeffs, root=0)
561+
# Revert to scalar behavior for constants
562+
if self.meta[-1]['constant']:
563+
return coeffs[0]
561564
# Build matrix
562565
ncc_basis_meta = self.meta[-1]
563566
arg_basis_meta = dict(frozen_arg_basis_meta)

dedalus/core/operators.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,6 +1080,12 @@ class TimeDerivative(LinearOperator, FutureField):
10801080

10811081
name = 'dt'
10821082

1083+
def __new__(cls, arg0, **kw):
1084+
if not isinstance(arg0, Operand):
1085+
return 0
1086+
else:
1087+
return object.__new__(cls)
1088+
10831089
@property
10841090
def base(self):
10851091
return TimeDerivative

dedalus/tests/test_ivp.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,11 @@ def test_heat_1d_periodic(benchmark, x_basis_class, Nx, timestepper, dtype):
3030
x = domain.grid(0)
3131
F['g'] = -np.sin(x)
3232
# Problem
33-
problem = de.IVP(domain, variables=['u'])
33+
problem = de.IVP(domain, variables=['u','ux'])
3434
problem.meta['u']['x']['parity'] = -1
3535
problem.parameters['F'] = F
36-
problem.add_equation("-dt(u) + dx(dx(u)) = F")
36+
problem.add_equation("-dt(u) + dx(ux) = F")
37+
problem.add_equation("ux - dx(u) = 0")
3738
# Solver
3839
solver = problem.build_solver(timestepper)
3940
dt = 1e-5

0 commit comments

Comments
 (0)