diff --git a/solar_apparent_time/solar_apparent_time.py b/solar_apparent_time/solar_apparent_time.py index a824ea3..879fa42 100644 --- a/solar_apparent_time/solar_apparent_time.py +++ b/solar_apparent_time/solar_apparent_time.py @@ -62,9 +62,33 @@ def _broadcast_time_and_space(times: np.ndarray, lons: np.ndarray) -> tuple[np.n return np.broadcast_arrays(times[..., None], lons) +def extract_lat_lon(geometry: Union[SpatialGeometry, GeoSeries]) -> SpatialGeometry: + """ + Extract the SpatialGeometry from a RasterGeometry or GeoSeries. + + Parameters + ---------- + geometry : SpatialGeometry or GeoSeries + The geometry object to extract from. + + Returns + ------- + SpatialGeometry + The extracted SpatialGeometry. + """ + if isinstance(geometry, SpatialGeometry): + lat = geometry.lat + lon = geometry.lon + elif isinstance(geometry, GeoSeries): + lat = geometry.y + lon = geometry.x + else: + raise ValueError("geometry must be SpatialGeometry or GeoSeries") + return lat, lon + def calculate_solar_hour_of_day( time_UTC: Union[datetime, str, list, np.ndarray], - geometry: SpatialGeometry = None, + geometry: Union[SpatialGeometry, GeoSeries] = None, lat: Union[np.ndarray, float] = None, lon: Union[np.ndarray, float] = None ) -> np.ndarray: @@ -93,17 +117,17 @@ def calculate_solar_hour_of_day( """ times = _parse_time(time_UTC) - if geometry is not None: - lon = geometry.lon - elif lon is not None: - lon = np.asarray(lon) + if lat is None or lon is None and geometry is not None: + lat, lon = extract_lat_lon(geometry) + + times = np.asarray(times) + lon = np.asarray(lon) + if times.ndim == 1 and lon.ndim == 1 and times.shape == lon.shape: + times_b = times + lons_b = lon else: - raise ValueError('Must provide either spatial or lon.') - - # Broadcast times and lons - times_b, lons_b = _broadcast_time_and_space(times, lon) + times_b, lons_b = _broadcast_time_and_space(times, lon) - # Calculate hour_UTC hour_UTC = ( times_b.astype('datetime64[h]').astype(int) % 24 + (times_b.astype('datetime64[m]').astype(int) % 60) / 60 @@ -119,7 +143,7 @@ def calculate_solar_hour_of_day( def calculate_solar_day_of_year( time_UTC: Union[datetime, str, list, np.ndarray], - geometry: SpatialGeometry = None, + geometry: Union[SpatialGeometry, GeoSeries] = None, lat: Union[np.ndarray, float] = None, lon: Union[np.ndarray, float] = None ) -> np.ndarray: @@ -148,20 +172,23 @@ def calculate_solar_day_of_year( """ times = _parse_time(time_UTC) - # If latitude is not provided, try to extract from geometry - if lat is None and isinstance(geometry, SpatialGeometry): - lat = geometry.lat - elif lat is None and isinstance(geometry, GeoSeries): - lat = geometry.y - elif lat is None: - raise ValueError("no latitude provided") - - if lon is None and isinstance(geometry, SpatialGeometry): - lon = geometry.lon - elif lon is None and isinstance(geometry, GeoSeries): - lon = geometry.x - elif lon is None: - raise ValueError("no longitude provided") + # # If latitude is not provided, try to extract from geometry + # if lat is None and isinstance(geometry, SpatialGeometry): + # lat = geometry.lat + # elif lat is None and isinstance(geometry, GeoSeries): + # lat = geometry.y + # elif lat is None: + # raise ValueError("no latitude provided") + + # if lon is None and isinstance(geometry, SpatialGeometry): + # lon = geometry.lon + # elif lon is None and isinstance(geometry, GeoSeries): + # lon = geometry.x + # elif lon is None: + # raise ValueError("no longitude provided") + + if lat is None or lon is None and geometry is not None: + lat, lon = extract_lat_lon(geometry) # Handle 1D time and lon inputs of the same length: pair element-wise times = np.asarray(times)