Skip to content

Commit 4caf82e

Browse files
committed
refactor: final typing fixes, conforming to Type Checking = Standard in Pyright
1 parent 965e657 commit 4caf82e

File tree

3 files changed

+16
-12
lines changed

3 files changed

+16
-12
lines changed

glassure/utility.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
"calculate_weighting_factor",
3535
]
3636

37-
Composition = Dict[str, Union[int, float]]
37+
Composition = dict[str, float] | dict[str, int]
3838

3939

4040
def parse_str_to_composition(formula: str) -> Composition:
@@ -341,8 +341,10 @@ def calculate_wkm_effective_atomic_number(
341341
return float(np.mean(form_factor / effective_form_factor))
342342

343343

344-
def calculate_total_atomic_number(composition: Composition) -> float:
345-
"""
344+
def calculate_total_atomic_number(
345+
composition: Composition | dict[str, float] | dict[str, int],
346+
) -> float:
347+
r"""
346348
Calculates the total atomic number of a given composition.
347349
"""
348350
return sum(
@@ -351,9 +353,11 @@ def calculate_total_atomic_number(composition: Composition) -> float:
351353

352354

353355
def calculate_effective_form_factor(
354-
composition: Composition, q: np.ndarray, sf_source="hajdu"
356+
composition: Composition,
357+
q: np.ndarray,
358+
sf_source="hajdu",
355359
) -> np.ndarray:
356-
"""
360+
r"""
357361
Calculates the effective form factor for a given composition, which is given by
358362
359363
.. math::
@@ -389,7 +393,7 @@ def normalize_composition(composition: Composition) -> dict[str, float]:
389393
for key, val in composition.items():
390394
sum += val
391395

392-
result = copy(composition)
396+
result: dict[str, float] = {key: float(val) for key, val in composition.items()}
393397

394398
for key in result:
395399
result[key] /= sum

tests/test_pattern.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def test_pydantic_nparray_with_list_input():
146146

147147
def test_pydantic_nparray_from_json():
148148
input = {"x": [1, 2, 3]}
149-
t = DummyModel(**input)
149+
t = DummyModel(**input) # type: ignore[arg-type]
150150
assert np.array_equal(t.x, np.array([1, 2, 3]))
151151

152152

tests/test_utility.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def test_linear_extrapolation(self):
142142

143143
x_linear = x1[x1 < 1]
144144
y_linear = y1[x1 < 1]
145-
self.assertAlmostEqual(np.sum(y_linear - x_linear), 0)
145+
self.assertAlmostEqual(float(np.sum(y_linear - x_linear)), 0)
146146

147147
def test_linear_extrapolation_with_different_y(self):
148148
x = np.arange(1, 5.05, 0.05)
@@ -197,7 +197,7 @@ def test_extrapolate_to_zero_poly(self):
197197
y_expected = a * (x_extrapolate - c) + b * (x_extrapolate - c) ** 2
198198
y_expected[x_extrapolate < c] = 0
199199

200-
self.assertAlmostEqual(np.sum(y_extrapolate - y_expected), 0)
200+
self.assertAlmostEqual(float(np.sum(y_extrapolate - y_expected)), 0.0)
201201

202202
extrapolated_pattern = extrapolate_to_zero_poly(pattern, x_max, replace=True)
203203
x1, y1 = extrapolated_pattern.data
@@ -208,7 +208,7 @@ def test_extrapolate_to_zero_poly(self):
208208
y_expected = a * (x_extrapolate - c) + b * (x_extrapolate - c) ** 2
209209
y_expected[x_extrapolate < c] = 0
210210

211-
self.assertAlmostEqual(np.sum(y_extrapolate - y_expected), 0)
211+
self.assertAlmostEqual(float(np.sum(y_extrapolate - y_expected)), 0)
212212

213213
def test_extrapolate_to_zero_poly_with_different_y(self):
214214
x = np.arange(1, 5.05, 0.05)
@@ -290,9 +290,9 @@ def test_calculate_effective_form_factor(self):
290290
self.assertEqual(len(q), len(effective_form_factor))
291291

292292
effective_form_factor = calculate_effective_form_factor(
293-
composition, 0, sf_source="hajdu"
293+
composition, np.array([0]), sf_source="hajdu"
294294
)
295-
self.assertAlmostEqual(effective_form_factor, 1.0, places=3)
295+
self.assertAlmostEqual(float(effective_form_factor[0]), 1.0, places=3)
296296

297297
def test_calculate_wkm_form_factor(self):
298298
q = 0.0

0 commit comments

Comments
 (0)