Skip to content

Commit 20a3d1f

Browse files
committed
refactor: fix more type issues
1 parent ca85050 commit 20a3d1f

File tree

4 files changed

+82
-40
lines changed

4 files changed

+82
-40
lines changed

glassure/calc.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def create_calculate_pdf_configs(
3737
data: Pattern,
3838
composition: Composition,
3939
density: float,
40-
bkg: Pattern = None,
40+
bkg: Pattern | None = None,
4141
bkg_scaling: float = 1,
4242
) -> tuple[DataConfig, CalculationConfig]:
4343
"""
@@ -79,6 +79,9 @@ def calculate_pdf(
7979
config = calculation_config
8080
transform = config.transform
8181
composition = config.sample.composition
82+
sample_atomic_density = config.sample.atomic_density
83+
if sample_atomic_density is None:
84+
raise ValueError("Sample atomic density must be provided for PDF calculation.")
8285

8386
# subtract background
8487
if data_config.bkg is not None:
@@ -150,7 +153,7 @@ def calculate_pdf(
150153

151154
n, norm = normalize(
152155
sample_pattern=sample,
153-
atomic_density=config.sample.atomic_density,
156+
atomic_density=sample_atomic_density,
154157
f_squared_mean=f_squared_mean,
155158
f_mean_squared=f_mean_squared,
156159
incoherent_scattering=norm_inc,
@@ -207,7 +210,7 @@ def calculate_pdf(
207210
opt = config.optimize
208211
sq = optimize_sq(
209212
sq,
210-
atomic_density=config.sample.atomic_density,
213+
atomic_density=sample_atomic_density,
211214
r_cutoff=opt.r_cutoff,
212215
r_step=transform.r_step,
213216
iterations=opt.iterations,
@@ -228,7 +231,7 @@ def calculate_pdf(
228231

229232
gr = calculate_gr(
230233
fr,
231-
atomic_density=config.sample.atomic_density,
234+
atomic_density=sample_atomic_density,
232235
)
233236

234237
res = Result(

glassure/configuration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,15 @@ class SampleConfig(BaseModel):
3434
)
3535
@property
3636
def atomic_density(self) -> Optional[float]:
37-
if self.composition == {}: # empty composition
37+
if self.composition == {} or self.density is None: # empty composition or density is not set
3838
return None
3939
return convert_density_to_atoms_per_cubic_angstrom(
4040
self.composition, self.density
4141
)
4242

4343
@atomic_density.setter
4444
def atomic_density(self, value: Optional[float]):
45-
if self.composition == {}:
45+
if self.composition == {} or value is None:
4646
self.density = None
4747
else:
4848
self.density = convert_density_to_grams_per_cubic_centimeter(

glassure/optimization.py

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# -*- coding: utf-8 -*-
22

33
from copy import deepcopy
4-
from typing import Optional
4+
from typing import Any
55

66
import numpy as np
77
from lmfit import Parameters, minimize
@@ -179,14 +179,14 @@ def fit_polynom_through_origin(x, y , degree: int) -> np.ndarray:
179179

180180

181181
from .calc import calculate_pdf
182-
from .configuration import CalculationConfig, DataConfig
182+
from .configuration import CalculationConfig, DataConfig, Result
183183

184184

185185
def optimize_density(
186186
data_config: DataConfig,
187187
calculation_config: CalculationConfig,
188188
method: str = "fr",
189-
min_range: Optional[tuple[float, float]] = None,
189+
min_range: tuple[float, float] | None = None,
190190
vary_bkg_scaling: bool = True,
191191
bkg_limits: tuple[float, float] = (0.9, 1.1),
192192
optimization_method: str = "lsq",
@@ -272,17 +272,30 @@ def optimize_density(
272272
)
273273

274274
optim_config = calculation_config.model_copy(deep=True)
275+
reference_result: Result | None = None
276+
range_limits: tuple[float, float] | None = min_range
275277

276278
if method == "sq":
277279
reference_config = calculation_config.model_copy(deep=True)
278280
reference_config.optimize = None
279281
reference_result = calculate_pdf(data_config, reference_config)
280-
if min_range is None:
281-
min_range = (0, reference_config.transform.q_max)
282+
if range_limits is None:
283+
range_limits = (0, reference_config.transform.q_max)
284+
elif method in ("gr", "fr"):
285+
if range_limits is None:
286+
optimize_settings = optim_config.optimize
287+
if optimize_settings is None:
288+
raise ValueError(
289+
"Optimization range cannot be inferred because calculation_config.optimize is None."
290+
)
291+
range_limits = (0, optimize_settings.r_cutoff)
292+
else:
293+
raise ValueError(
294+
f"Invalid optimize density method: {method}, only 'gr', 'fr' and 'sq' are supported."
295+
)
282296

283-
elif method == "gr" or method == "fr":
284-
if min_range is None:
285-
min_range = (0, optim_config.optimize.r_cutoff)
297+
if range_limits is None:
298+
raise ValueError("Optimization range must be specified.")
286299

287300
def fcn(params):
288301
density = params["density"].value
@@ -292,46 +305,52 @@ def fcn(params):
292305
result = calculate_pdf(data_config, optim_config)
293306

294307
if method == "gr":
295-
r, gr = result.gr.limit(*min_range).data
308+
if result.gr is None:
309+
raise ValueError("Result does not contain g(r) data required for 'gr' optimization.")
310+
r, gr = result.gr.limit(*range_limits).data
296311
residual = gr * (r[1] - r[0])
297312
elif method == "fr":
313+
if result.fr is None:
314+
raise ValueError("Result does not contain F(r) data required for 'fr' optimization.")
298315
atomic_density = optim_config.sample.atomic_density
299-
r, fr = result.fr.limit(*min_range).data
316+
if atomic_density is None:
317+
raise ValueError("Sample atomic density must be set for 'fr' optimization.")
318+
r, fr = result.fr.limit(*range_limits).data
300319
residual = (fr + 4 * np.pi * r * atomic_density) * (r[1] - r[0])
301320
elif method == "sq":
302-
q, sq = result.sq.limit(*min_range).data
303-
sq_ref = reference_result.sq.limit(*min_range).y
321+
if reference_result is None or reference_result.sq is None:
322+
raise ValueError("Reference result does not contain S(q) data required for 'sq' optimization.")
323+
if result.sq is None:
324+
raise ValueError("Result does not contain S(q) data required for 'sq' optimization.")
325+
q, sq = result.sq.limit(*range_limits).data
326+
sq_ref = reference_result.sq.limit(*range_limits).y
304327
residual = (sq - sq_ref) * (q[1] - q[0])
305-
else:
306-
raise ValueError(
307-
f"Invalid optimize density method: {method}, only 'gr', 'fr' and 'sq' are supported."
308-
)
309328
return residual
310329

311330
if optimization_method == "nelder":
312-
res = minimize(
331+
nelder_res: Any = minimize(
313332
fcn,
314333
params,
315334
method="nelder",
316335
options={"maxfev": 500, "fatol": 0.0001, "xatol": 0.0001},
317336
)
318337
return (
319-
res.params["density"].value,
320-
np.sum(res.residual**2),
321-
res.params["bkg_scaling"].value,
322-
np.sum(res.residual**2),
338+
nelder_res.params["density"].value,
339+
np.sum(nelder_res.residual**2),
340+
nelder_res.params["bkg_scaling"].value,
341+
np.sum(nelder_res.residual**2),
323342
)
324343
elif optimization_method == "lsq":
325-
res = minimize(
344+
lsq_res: Any = minimize(
326345
fcn,
327346
params,
328347
method="least_squares",
329348
)
330349
return (
331-
res.params["density"].value,
332-
res.params["density"].stderr,
333-
res.params["bkg_scaling"].value,
334-
res.params["bkg_scaling"].stderr,
350+
lsq_res.params["density"].value,
351+
lsq_res.params["density"].stderr,
352+
lsq_res.params["bkg_scaling"].value,
353+
lsq_res.params["bkg_scaling"].stderr,
335354
)
336355
else:
337356
raise ValueError(f"Invalid optimization method: {optimization_method}")

glassure/pattern.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from __future__ import annotations
33
import os
44
from typing_extensions import Annotated
5-
from typing import Union, Optional, TYPE_CHECKING
5+
from typing import Union, Optional, TYPE_CHECKING, Sequence
66
from pydantic import PlainSerializer, PlainValidator
77
import numpy as np
88
import base64
@@ -304,21 +304,22 @@ def __radd__(self, other: Union[float, Pattern]) -> Pattern:
304304
"""
305305
return self.__add__(other)
306306

307-
def __rmul__(self, other: float) -> Pattern:
307+
def __rmul__(self, other: Union[float, np.ndarray, Sequence[float]]) -> Pattern:
308308
"""
309-
Multiplies the pattern with a scalar.
309+
Multiplies the pattern with a scalar or an array-like of the same shape as the y-values.
310310
311-
:param other: scalar to multiply with
311+
:param other: scalar or array-like to multiply with
312312
:return: new Pattern
313313
"""
314314
orig_x, orig_y = self.data
315-
return Pattern(np.copy(orig_x), np.copy(orig_y) * other)
315+
multiplier = self._normalize_multiplier(other, orig_y.shape)
316+
return Pattern(np.copy(orig_x), np.multiply(orig_y, multiplier))
316317

317-
def __mul__(self, other: float) -> Pattern:
318+
def __mul__(self, other: Union[float, np.ndarray, Sequence[float]]) -> Pattern:
318319
"""
319-
Multiplies the pattern with a scalar.
320+
Multiplies the pattern with a scalar or an array-like of the same shape as the y-values.
320321
321-
:param other: scalar to multiply with
322+
:param other: scalar or array-like to multiply with
322323
:return: new Pattern
323324
"""
324325
return self.__rmul__(other)
@@ -337,6 +338,25 @@ def __eq__(self, other: object) -> bool:
337338
return True
338339
return False
339340

341+
@staticmethod
342+
def _normalize_multiplier(
343+
multiplier: Union[float, np.ndarray, Sequence[float]], target_shape: tuple[int, ...]
344+
) -> np.ndarray:
345+
"""
346+
Normalizes the multiplier to a float or an array of the same shape as the target shape.
347+
:param multiplier: The multiplier to normalize
348+
:param target_shape: The shape of the target array
349+
:return: The normalized multiplier
350+
"""
351+
array_multiplier = np.asarray(multiplier, dtype=float)
352+
if array_multiplier.ndim == 0:
353+
return array_multiplier
354+
if array_multiplier.shape != target_shape:
355+
raise ValueError(
356+
"Array multiplier must have the same shape as the pattern's y values."
357+
)
358+
return array_multiplier
359+
340360

341361
class BkgNotInRangeError(Exception):
342362
def __init__(self, pattern_name: str):

0 commit comments

Comments
 (0)