From 053e1f2ab93b1e0c6a401e021b15447a574be87e Mon Sep 17 00:00:00 2001 From: RHammond2 <13874373+RHammond2@users.noreply.github.com> Date: Mon, 18 Dec 2023 16:04:10 -0800 Subject: [PATCH] update checking for grid_shape --- floris/type_dec.py | 8 +++----- tests/type_dec_unit_test.py | 12 +++++++++++- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/floris/type_dec.py b/floris/type_dec.py index 0ef047bfe..5c8d3842d 100644 --- a/floris/type_dec.py +++ b/floris/type_dec.py @@ -193,22 +193,20 @@ def validate_5DArray_shape(instance, attribute: Attribute, value: np.ndarray) -> N wind directions x N wind speeds x N turbines x N grid points x N grid points. """ if not isinstance(value, np.ndarray): - print(type(value)) raise TypeError(f"`{attribute.name}` is not a valid NumPy array type.") # Don't fail on the initialized empty array if value.size == 0: return - grid = instance.grid_resolution - shape = (instance.n_wind_directions, instance.n_wind_speeds, instance.n_turbines, grid, grid) - if value.shape != shape: + if value.shape != instance.grid_shape: broadcast_shape = ( instance.n_wind_directions, instance.n_wind_speeds, instance.n_turbines, 1, 1 ) if value.shape != broadcast_shape: raise ValueError( - f"`{attribute.name}` should have shape: {shape}; not shape: {value.shape}" + f"`{attribute.name}` should have shape: {instance.grid_shape}; not shape: " + f"{value.shape}" ) diff --git a/tests/type_dec_unit_test.py b/tests/type_dec_unit_test.py index a4f09184f..ff21b7858 100644 --- a/tests/type_dec_unit_test.py +++ b/tests/type_dec_unit_test.py @@ -62,7 +62,7 @@ class ArrayValidatorDemoClass(ValidateMixin): n_wind_directions: int = field(default=3) n_wind_speeds: int = field(default=4) n_turbines: int = field(default=10) - grid_resolution: int = field(default=3) + grid_shape: tuple[int, int, int, int, int] = field(init=False) three_dim_starts_as_one: NDArrayFloat = field( factory=lambda: np.array([]), validator=validate_3DArray_shape ) @@ -70,6 +70,15 @@ class ArrayValidatorDemoClass(ValidateMixin): three_dimensions_provided: NDArrayFloat = array_3D_field mixed_dimensions_provided: NDArrayFloat = array_mixed_dim_field + def set_grid(self, grid_resolution: int): + self.grid_shape = ( + self.n_wind_directions, + self.n_wind_speeds, + self.n_turbines, + grid_resolution, + grid_resolution + ) + def test_as_dict(): # Non-initialized attributes should not be exported @@ -154,6 +163,7 @@ def test_array_validators(): # Check initialization works demo = ArrayValidatorDemoClass() + demo.set_grid(3) demo.validate() # Check assignment with correct shape: 3 x 4 x 10 (x 3 x 3)