Skip to content

Commit 03ab912

Browse files
buggy update
1 parent 7613328 commit 03ab912

24 files changed

+275
-112
lines changed

pina/domain/base_domain.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Module for the Base Domain class."""
22

3+
from copy import deepcopy
34
from .domain_interface import DomainInterface
45
from ..utils import check_consistency, check_positive_integer
56

@@ -89,10 +90,12 @@ def update(self, domain):
8990
Each new label introduces a new dimension. Only domains of the same type
9091
can be used for update.
9192
92-
:param DomainInterface domain: The domain whose labels are to be merged
93+
:param BaseDomain domain: The domain whose labels are to be merged
9394
into the current one.
9495
:raises TypeError: If the provided domain is not of the same type as
9596
the current one.
97+
:return: A new domain instance with the merged labels.
98+
:rtype: BaseDomain
9699
"""
97100
# Raise an error if the domain types do not match
98101
if not isinstance(domain, type(self)):
@@ -102,8 +105,11 @@ def update(self, domain):
102105
)
103106

104107
# Update fixed and ranged variables
105-
self._fixed.update(domain._fixed)
106-
self._range.update(domain._range)
108+
updated = deepcopy(self)
109+
updated._fixed.update(domain._fixed)
110+
updated._range.update(domain._range)
111+
112+
return updated
107113

108114
def _validate_sampling(self, n, mode, variables):
109115
"""
@@ -164,7 +170,7 @@ def sample_modes(self, values):
164170
Setter for the ``sample_modes`` property.
165171
166172
:param values: The sampling modes to be set.
167-
:type values: str | list[str]
173+
:type values: str | list[str] | tuple[str]
168174
:raises ValueError: Invalid sampling mode.
169175
"""
170176
# Ensure values is a list

pina/domain/base_operation.py

Lines changed: 66 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Module for the Base Operation class."""
22

3+
from copy import deepcopy
34
from .operation_interface import OperationInterface
45
from .base_domain import BaseDomain
56
from ..utils import check_consistency, check_positive_integer
@@ -31,26 +32,7 @@ def __init__(self, geometries):
3132
:raises NotImplementedError: If the dimensions of the geometries are not
3233
consistent.
3334
"""
34-
# Check geometries are list or tuple
35-
if not isinstance(geometries, (list, tuple)):
36-
raise TypeError(
37-
"geometries must be either a list or a tuple of BaseDomain."
38-
)
39-
40-
# Check consistency
41-
check_consistency(geometries, (BaseDomain, BaseOperation))
42-
43-
# Check geometries
44-
for geometry in geometries:
45-
if geometry.variables != geometries[0].variables:
46-
raise NotImplementedError(
47-
f"The {self.__class__.__name__} of geometries living in "
48-
"different ambient spaces is not well-defined. "
49-
"All geometries must share the same dimensions and labels."
50-
)
51-
52-
# Initialization
53-
self._geometries = geometries
35+
self.geometries = geometries
5436

5537
def _validate_sampling(self, n, mode, variables):
5638
"""
@@ -95,18 +77,40 @@ def _validate_sampling(self, n, mode, variables):
9577

9678
return sorted(variables)
9779

98-
def update(self, _):
80+
def update(self, domain):
9981
"""
10082
Update the domain resulting from the operation.
10183
102-
:raises NotImplementedError: The :meth:`update` method is not
103-
implemented for operation domains. Please update the individual
104-
domains instead.
105-
"""
106-
raise NotImplementedError(
107-
"The update method is not implemented for operation domains. "
108-
"Please update the individual domains instead."
109-
)
84+
:param DomainInterface domain: The domain whose labels are to be merged
85+
into the current one.
86+
:raises NotImplementedError: If the geometries involved in the operation
87+
are of different types.
88+
:raises TypeError: If the passed domain is not of the same type of all
89+
the geometries involved in the operation.
90+
:return: A new domain instance with the merged labels.
91+
:rtype: BaseOperation
92+
"""
93+
# Check all geometries are of the same type
94+
domain_type = type(self.geometries[0])
95+
if not all(isinstance(g, domain_type) for g in self.geometries):
96+
raise NotImplementedError(
97+
f"The {self.__class__.__name__} of geometries of different"
98+
" types does not support the update operation. All geometries"
99+
" must be of the same type."
100+
)
101+
102+
# Check domain type consistency
103+
if not isinstance(domain, domain_type):
104+
raise TypeError(
105+
f"Cannot update the {self.__class__.__name__} of domains of"
106+
f" type {domain_type} with domain of type {type(domain)}."
107+
)
108+
109+
# Update each geometry
110+
updated = deepcopy(self)
111+
updated.geometries = [geom.update(domain) for geom in self.geometries]
112+
113+
return updated
110114

111115
def partial(self):
112116
"""
@@ -167,3 +171,36 @@ def geometries(self):
167171
:rtype: list[BaseDomain]
168172
"""
169173
return self._geometries
174+
175+
@geometries.setter
176+
def geometries(self, values):
177+
"""
178+
Setter for the ``geometries`` property.
179+
180+
:param values: The geometries to be set.
181+
:type values: list[BaseDomain] | tuple[BaseDomain]
182+
:raises TypeError: If values is neither a list nor a tuple.
183+
:raises ValueError: If values elements are not instances of
184+
:class:`~pina.domain.base_domain.BaseDomain`.
185+
:raises NotImplementedError: If the dimensions of the geometries are not
186+
consistent.
187+
"""
188+
# Check geometries are list or tuple
189+
if not isinstance(values, (list, tuple)):
190+
raise TypeError(
191+
"geometries must be either a list or a tuple of BaseDomain."
192+
)
193+
194+
# Check consistency
195+
check_consistency(values, (BaseDomain, BaseOperation))
196+
197+
# Check geometries
198+
for v in values:
199+
if v.variables != values[0].variables:
200+
raise NotImplementedError(
201+
f"The {self.__class__.__name__} of geometries living in "
202+
"different ambient spaces is not well-defined. "
203+
"All geometries must share the same dimensions and labels."
204+
)
205+
206+
self._geometries = values

pina/domain/domain_interface.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ def update(self, domain):
2929
3030
:param BaseDomain domain: The domain whose labels are to be merged into
3131
the current one.
32+
:return: A new domain instance with the merged labels.
33+
:rtype: DomainInterface
3234
"""
3335

3436
@abstractmethod
@@ -69,7 +71,7 @@ def sample_modes(self, values):
6971
Setter for the :attr:`sample_modes` property.
7072
7173
:param values: Sampling modes to be set.
72-
:type values: str | list[str]
74+
:type values: str | list[str] | tuple[str]
7375
"""
7476

7577
@property

pina/domain/ellipsoid_domain.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,13 @@ def update(self, domain):
128128
into the current one.
129129
:raises TypeError: If the provided domain is not of an instance of
130130
:class:`EllipsoidDomain`.
131+
:return: A new domain instance with the merged labels.
132+
:rtype: EllipsoidDomain
131133
"""
132-
super().update(domain)
133-
self._compute_center_axes()
134+
updated = super().update(domain)
135+
updated._compute_center_axes()
136+
137+
return updated
134138

135139
def sample(self, n, mode="random", variables="all"):
136140
"""

pina/domain/operation_interface.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,13 @@ def geometries(self):
1818
:return: The list of domains on which to perform the set operation.
1919
:rtype: list[BaseDomain]
2020
"""
21+
22+
@geometries.setter
23+
@abstractmethod
24+
def geometries(self, values):
25+
"""
26+
Setter for the ``geometries`` property.
27+
28+
:param values: The geometries to be set.
29+
:type values: list[BaseDomain] | tuple[BaseDomain]
30+
"""

pina/domain/simplex_domain.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ def update(self, domain):
113113
:param SimplexDomain domain: The domain whose vertices are to be set
114114
into the current one.
115115
:raises TypeError: If the domain is not a :class:`SimplexDomain` object.
116+
:return: A new domain instance with the merged labels.
117+
:rtype: SimplexDomain
116118
"""
117119
# Raise an error if the domain types do not match
118120
if not isinstance(domain, type(self)):
@@ -122,7 +124,10 @@ def update(self, domain):
122124
)
123125

124126
# Replace geometry
125-
self._vert_matrix = domain._vert_matrix
127+
updated = deepcopy(self)
128+
updated._vert_matrix = domain._vert_matrix
129+
130+
return updated
126131

127132
def sample(self, n, mode="random", variables="all"):
128133
"""

pina/problem/abstract_problem.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import warnings
55
from copy import deepcopy
66
from ..utils import check_consistency
7-
from ..domain import DomainInterface, CartesianDomain
7+
from ..domain import DomainInterface, CartesianDomain, OperationInterface
88
from ..condition.domain_equation_condition import DomainEquationCondition
99
from ..label_tensor import LabelTensor
1010
from ..utils import merge_tensors, custom_warning_format
@@ -288,7 +288,7 @@ def _apply_custom_discretization(self, sample_rules, domains):
288288
"the input variables."
289289
)
290290
for domain in domains:
291-
if not isinstance(self.domains[domain], CartesianDomain):
291+
if not isinstance(self.domains[domain], (CartesianDomain, OperationInterface)):
292292
raise RuntimeError(
293293
"Custom discretisation can be applied only on Cartesian "
294294
"domains"

pina/problem/zoo/advection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ class AdvectionProblem(SpatialProblem, TimeDependentProblem):
4141
temporal_domain = CartesianDomain({"t": [0, 1]})
4242

4343
domains = {
44-
"D": CartesianDomain({"x": [0, 2 * torch.pi], "t": [0, 1]}),
45-
"t0": CartesianDomain({"x": [0, 2 * torch.pi], "t": 0.0}),
44+
"D": spatial_domain.update(temporal_domain),
45+
"t0": spatial_domain.update(CartesianDomain({"t": 0})),
4646
}
4747

4848
conditions = {

pina/problem/zoo/allen_cahn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ class AllenCahnProblem(TimeDependentProblem, SpatialProblem):
4545
temporal_domain = CartesianDomain({"t": [0, 1]})
4646

4747
domains = {
48-
"D": CartesianDomain({"x": [-1, 1], "t": [0, 1]}),
49-
"t0": CartesianDomain({"x": [-1, 1], "t": 0.0}),
48+
"D": spatial_domain.update(temporal_domain),
49+
"t0": spatial_domain.update(CartesianDomain({"t": 0})),
5050
}
5151

5252
conditions = {

pina/problem/zoo/diffusion_reaction.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,13 @@ class DiffusionReactionProblem(TimeDependentProblem, SpatialProblem):
4747
temporal_domain = CartesianDomain({"t": [0, 1]})
4848

4949
domains = {
50-
"D": CartesianDomain({"x": [-torch.pi, torch.pi], "t": [0, 1]}),
51-
"g1": CartesianDomain({"x": -torch.pi, "t": [0, 1]}),
52-
"g2": CartesianDomain({"x": torch.pi, "t": [0, 1]}),
53-
"t0": CartesianDomain({"x": [-torch.pi, torch.pi], "t": 0.0}),
50+
"D": spatial_domain.update(temporal_domain),
51+
"boundary": spatial_domain.partial().update(temporal_domain),
52+
"t0": spatial_domain.update(CartesianDomain({"t": 0})),
5453
}
5554

5655
conditions = {
57-
"g1": Condition(domain="g1", equation=FixedValue(0.0)),
58-
"g2": Condition(domain="g2", equation=FixedValue(0.0)),
56+
"boundary": Condition(domain="boundary", equation=FixedValue(0.0)),
5957
"t0": Condition(domain="t0", equation=Equation(initial_condition)),
6058
}
6159

0 commit comments

Comments
 (0)