Skip to content

Commit

Permalink
update checking for grid_shape
Browse files Browse the repository at this point in the history
  • Loading branch information
RHammond2 committed Dec 19, 2023
1 parent f2a0dd4 commit 053e1f2
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
8 changes: 3 additions & 5 deletions floris/type_dec.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,22 +193,20 @@ def validate_5DArray_shape(instance, attribute: Attribute, value: np.ndarray) ->
N wind directions x N wind speeds x N turbines x N grid points x N grid points.
"""
if not isinstance(value, np.ndarray):
print(type(value))
raise TypeError(f"`{attribute.name}` is not a valid NumPy array type.")

# Don't fail on the initialized empty array
if value.size == 0:
return

grid = instance.grid_resolution
shape = (instance.n_wind_directions, instance.n_wind_speeds, instance.n_turbines, grid, grid)
if value.shape != shape:
if value.shape != instance.grid_shape:
broadcast_shape = (
instance.n_wind_directions, instance.n_wind_speeds, instance.n_turbines, 1, 1
)
if value.shape != broadcast_shape:
raise ValueError(
f"`{attribute.name}` should have shape: {shape}; not shape: {value.shape}"
f"`{attribute.name}` should have shape: {instance.grid_shape}; not shape: "
f"{value.shape}"
)


Expand Down
12 changes: 11 additions & 1 deletion tests/type_dec_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,23 @@ class ArrayValidatorDemoClass(ValidateMixin):
n_wind_directions: int = field(default=3)
n_wind_speeds: int = field(default=4)
n_turbines: int = field(default=10)
grid_resolution: int = field(default=3)
grid_shape: tuple[int, int, int, int, int] = field(init=False)
three_dim_starts_as_one: NDArrayFloat = field(
factory=lambda: np.array([]), validator=validate_3DArray_shape
)
five_dimensions_provided: NDArrayFloat = array_5D_field
three_dimensions_provided: NDArrayFloat = array_3D_field
mixed_dimensions_provided: NDArrayFloat = array_mixed_dim_field

def set_grid(self, grid_resolution: int):
self.grid_shape = (
self.n_wind_directions,
self.n_wind_speeds,
self.n_turbines,
grid_resolution,
grid_resolution
)


def test_as_dict():
# Non-initialized attributes should not be exported
Expand Down Expand Up @@ -154,6 +163,7 @@ def test_array_validators():

# Check initialization works
demo = ArrayValidatorDemoClass()
demo.set_grid(3)
demo.validate()

# Check assignment with correct shape: 3 x 4 x 10 (x 3 x 3)
Expand Down

0 comments on commit 053e1f2

Please sign in to comment.