Skip to content

Commit f2b5fa2

Browse files
committed
refactor: fix type errors in normalization
1 parent 20a3d1f commit f2b5fa2

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

glassure/normalization.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Optional
1+
2+
from typing import Any
23

34
import numpy as np
45
import lmfit
@@ -12,7 +13,7 @@ def normalize(
1213
atomic_density: float,
1314
f_squared_mean: np.ndarray,
1415
f_mean_squared: np.ndarray,
15-
incoherent_scattering: Optional[np.ndarray] = None,
16+
incoherent_scattering: np.ndarray | None = None,
1617
attenuation_factor: float = 0.001,
1718
) -> tuple[float, Pattern]:
1819
"""
@@ -52,11 +53,11 @@ def normalize(
5253
def normalize_fit_lmfit(
5354
sample_pattern: Pattern,
5455
f_squared_mean: np.ndarray,
55-
incoherent_scattering: Optional[np.ndarray] = None,
56+
incoherent_scattering: np.ndarray | None = None,
5657
q_cutoff: float = 3,
5758
method: str = "squared",
5859
multiple_scattering: bool = False,
59-
container_scattering: Optional[np.ndarray] = None,
60+
container_scattering: np.ndarray | None = None,
6061
) -> tuple[lmfit.Parameters, Pattern]:
6162
"""
6263
This function is deprecated and will be removed in the future. It is replaced by a new
@@ -112,8 +113,8 @@ def normalize_fit_lmfit(
112113

113114
# calculate values for integrals
114115
if incoherent_scattering is None:
115-
incoherent_scattering = 0
116-
incoherent_scattering_cut = 0
116+
incoherent_scattering = np.array(0)
117+
incoherent_scattering_cut = np.array(0)
117118
else:
118119
assert len(incoherent_scattering) == len(
119120
q
@@ -149,8 +150,8 @@ def normalize_fit_lmfit(
149150
container_contribution_cut = container_contribution[q_ind]
150151
else:
151152
params.add("n_container", value=0, vary=False)
152-
container_contribution = 0
153-
container_contribution_cut = 0
153+
container_contribution = np.array(0)
154+
container_contribution_cut = np.array(0)
154155

155156
def optimization_fcn(params):
156157
n = params["n"].value
@@ -166,7 +167,7 @@ def optimization_fcn(params):
166167
theory = f_squared_mean_cut + compton
167168
return ((n * intensity_cut - multiple - theory) * scaling) ** 2
168169

169-
out = lmfit.minimize(optimization_fcn, params)
170+
out: Any = lmfit.minimize(optimization_fcn, params)
170171

171172
# prepare final output
172173
q_out = sample_pattern.x
@@ -183,11 +184,11 @@ def optimization_fcn(params):
183184
def normalize_fit(
184185
sample_pattern: Pattern,
185186
f_squared_mean: np.ndarray,
186-
incoherent_scattering: Optional[np.ndarray] = None,
187+
incoherent_scattering: np.ndarray | None = None,
187188
q_cutoff: float = 3,
188189
method: str = "squared",
189190
multiple_scattering: bool = False,
190-
container_scattering: Optional[np.ndarray] = None,
191+
container_scattering: np.ndarray | None = None,
191192
) -> tuple[dict, Pattern]:
192193
"""
193194
Estimates the normalization factor n for calculating S(Q) by solving the linear least squares problem

0 commit comments

Comments
 (0)