Skip to content

Commit 5d07718

Browse files
committed
refactor: more clean up of test optimization
1 parent 36dc0bc commit 5d07718

File tree

4 files changed

+95
-46
lines changed

4 files changed

+95
-46
lines changed

glassure/calc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def create_calculate_pdf_configs(
6969

7070
def calculate_pdf(
7171
data_config: DataConfig, calculation_config: CalculationConfig
72-
) -> Pattern:
72+
) -> Result:
7373
"""
7474
Process the input configuration and return the result.
7575
"""

glassure/optimization.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from . import Pattern
1010
from .transform import calculate_fr, calculate_gr, calculate_sq_from_fr
11+
from .utility import convert_density_to_atoms_per_cubic_angstrom
1112

1213
__all__ = [
1314
"optimize_sq",
@@ -20,6 +21,7 @@ def optimize_sq(
2021
r_cutoff: float,
2122
iterations: int,
2223
atomic_density: float,
24+
r_step: float = 0.01,
2325
use_modification_fcn: bool = False,
2426
attenuation_factor: float = 1,
2527
fcn_callback=None,
@@ -39,6 +41,9 @@ def optimize_sq(
3941
number of back and forward transforms
4042
:param atomic_density:
4143
density in atoms/A^3
44+
:param r_step:
45+
step size for the r-axis, default is 0.01. Use smaller values for better accuracy (especially if needed for
46+
fft)
4247
:param use_modification_fcn:
4348
Whether to use the Lorch modification function during the Fourier transform.
4449
Warning: When using the Lorch modification function, usually more iterations are needed to get to the
@@ -59,7 +64,7 @@ def optimize_sq(
5964
:return:
6065
optimized S(Q) pattern
6166
"""
62-
r = np.arange(0, r_cutoff, 0.01)
67+
r = np.arange(0, r_cutoff, r_step)
6368
sq_pattern = deepcopy(sq_pattern)
6469
for iteration in range(iterations):
6570
fr_pattern = calculate_fr(
@@ -71,13 +76,14 @@ def optimize_sq(
7176
delta_fr = fr_int + 4 * np.pi * r * atomic_density
7277

7378
if fourier_transform_method == "fft":
74-
sq_trans_fft = calculate_sq_from_fr(
75-
Pattern(r, delta_fr), sq_pattern.x, method="fft"
76-
) - 1
79+
sq_trans_fft = (
80+
calculate_sq_from_fr(Pattern(r, delta_fr), sq_pattern.x, method="fft")
81+
- 1
82+
)
7783
iq = sq_trans_fft.y
7884
else:
7985
in_integral = np.array(np.sin(np.outer(q.T, r))) * delta_fr
80-
iq = simpson(in_integral, r) / q
86+
iq = simpson(in_integral, x=r) / q
8187

8288
sq_pattern = sq_pattern * (1 - iq / attenuation_factor)
8389

@@ -177,6 +183,10 @@ def fcn(params):
177183
if type == "gr":
178184
r, gr = result.gr.limit(*min_range).data
179185
residual = np.trapz(gr**2, r)
186+
if type == "fr":
187+
atomic_density = optim_config.sample.atomic_density
188+
r, fr = result.fr.limit(*min_range).data
189+
residual = np.trapz(fr + 4 * np.pi * r * atomic_density)
180190
elif type == "sq":
181191
q, sq = result.sq.limit(*min_range).data
182192
sq_ref = reference_result.sq.limit(*min_range).y

tests/test_normalization.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
)
1212
from glassure.normalization import normalize, normalize_fit, normalize_fit_lmfit
1313
from glassure.scattering_factors import calculate_coherent_scattering_factor
14-
from glassure.transform import calculate_sq
1514

1615
from . import unittest_data_path
1716

tests/test_optimization.py

Lines changed: 79 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# -*- coding: utf-8 -*-
22

33
import os
4-
import unittest
4+
import pytest
55
import numpy as np
66

77
from glassure import Pattern, convert_density_to_atoms_per_cubic_angstrom
@@ -14,7 +14,7 @@
1414
calculate_s0,
1515
)
1616
from glassure.transform import calculate_sq, calculate_fr
17-
from glassure.normalization import normalize_fit
17+
from glassure.normalization import normalize_fit, normalize
1818
from glassure.configuration import OptimizeConfig
1919
from glassure.optimization import optimize_sq, optimize_density
2020
from glassure.calc import calculate_pdf, create_calculate_pdf_configs
@@ -32,58 +32,98 @@
3232
background_path_SiO2 = os.path.join(unittest_data_path, "SiO2_bkg.xy")
3333

3434

35-
def test_optimize_sq():
36-
data = Pattern.from_file(data_path_alloy)
37-
background = Pattern.from_file(background_path_alloy)
38-
composition = {"Fe": 0.81, "S": 0.19}
39-
density = 7.9
40-
atomic_density = convert_density_to_atoms_per_cubic_angstrom(composition, density)
41-
f_squared_mean = calculate_f_squared_mean(composition, data.x)
42-
f_mean_squared = calculate_f_mean_squared(composition, data.x)
43-
incoherent_scattering = calculate_incoherent_scattering(composition, data.x)
44-
background_scaling = 0.97
35+
@pytest.fixture
36+
def data_path():
37+
"""Path to the test data file."""
38+
return os.path.join(unittest_data_path, "SiO2.xy")
4539

46-
sample_pattern = data - background_scaling * background
47-
48-
sq = calculate_sq(sample_pattern, f_squared_mean, f_mean_squared)
49-
sq = extrapolate_to_zero_poly(sq, np.min(sq.x) + 0.3)
50-
sq_optimized = optimize_sq(sq, 1.6, 5, atomic_density)
51-
assert not np.allclose(sq.y, sq_optimized.y)
5240

41+
@pytest.fixture
42+
def bkg_path():
43+
"""Path to the background data file."""
44+
return os.path.join(unittest_data_path, "SiO2_bkg.xy")
5345

54-
def test_optimize_sq_fft():
55-
data = Pattern.from_file(data_path_SiO2)
56-
background = Pattern.from_file(background_path_SiO2)
57-
composition = {"Si": 1, "O": 2}
58-
density = 2.2
59-
atomic_density = convert_density_to_atoms_per_cubic_angstrom(composition, density)
60-
background_scaling = 1.0
6146

62-
sample_pattern = data - background_scaling * background
63-
sample_pattern = sample_pattern.limit(0, 17).rebin(0.05)
64-
q = sample_pattern.x
65-
f_squared_mean = calculate_f_squared_mean(composition, q)
66-
f_mean_squared = calculate_f_mean_squared(composition, q)
67-
incoherent_scattering = calculate_incoherent_scattering(composition, q)
47+
@pytest.fixture
48+
def sample(data_path, bkg_path):
49+
"""Create a sample pattern by subtracting background from data."""
50+
data = Pattern.from_file(data_path)
51+
bkg = Pattern.from_file(bkg_path)
52+
sample = data - bkg
53+
return sample.limit(1, 17)
6854

69-
_, norm_pattern = normalize_fit(
70-
sample_pattern, f_squared_mean, incoherent_scattering
71-
)
7255

73-
sq = calculate_sq(norm_pattern, f_squared_mean, f_mean_squared)
56+
@pytest.fixture
57+
def sq(normalized_pattern, f_squared_mean, f_mean_squared, composition):
58+
"""Create a sq pattern for testing."""
59+
sq = calculate_sq(normalized_pattern, f_squared_mean, f_mean_squared)
7460
sq = extrapolate_to_zero_linear(sq, y0=calculate_s0(composition))
61+
sq = sq.rebin(0.05)
7562
sq.x[0] = 1e-10
76-
iterations = 5
63+
return sq
64+
65+
66+
@pytest.fixture
67+
def composition():
68+
"""Sample composition for testing."""
69+
return {"Si": 1, "O": 2}
70+
71+
72+
@pytest.fixture
73+
def density():
74+
"""Sample density for testing."""
75+
return 2.2
76+
77+
78+
@pytest.fixture
79+
def atomic_density(composition, density):
80+
"""Calculate atomic density from composition and density."""
81+
return convert_density_to_atoms_per_cubic_angstrom(composition, density)
82+
83+
84+
@pytest.fixture
85+
def f_squared_mean(composition, sample):
86+
"""Calculate f squared mean for the sample."""
87+
return calculate_f_squared_mean(composition, sample.x)
7788

78-
fr = calculate_fr(sq, method="fft")
7989

90+
@pytest.fixture
91+
def f_mean_squared(composition, sample):
92+
"""Calculate f mean squared for the sample."""
93+
return calculate_f_mean_squared(composition, sample.x)
94+
95+
96+
@pytest.fixture
97+
def incoherent_scattering(composition, sample):
98+
"""Calculate incoherent scattering for the sample."""
99+
return calculate_incoherent_scattering(composition, sample.x)
100+
101+
102+
@pytest.fixture
103+
def normalized_pattern(sample, f_squared_mean, incoherent_scattering):
104+
"""Create normalized pattern for testing."""
105+
_, normalized = normalize_fit(
106+
sample, f_squared_mean, incoherent_scattering, q_cutoff=10
107+
)
108+
return normalized
109+
110+
111+
def test_optimize_sq(sq, atomic_density):
112+
sq_optimized = optimize_sq(sq, 1.4, 5, atomic_density)
113+
assert not np.allclose(sq.y, sq_optimized.y)
114+
115+
116+
def test_optimize_sq_fft(sq, atomic_density):
117+
iterations = 5
118+
r_step = 0.001 # need high value to be accurate for fft
80119
sq_optimized = optimize_sq(
81-
sq, 1.3, iterations, atomic_density, fourier_transform_method="integral"
120+
sq, 1.3, iterations, atomic_density, fourier_transform_method="integral", r_step=r_step
82121
)
83122
fr_optimized = calculate_fr(sq_optimized, method="fft")
84123

124+
85125
sq_optimized_fft = optimize_sq(
86-
sq, 1.3, iterations, atomic_density, fourier_transform_method="fft"
126+
sq, 1.3, iterations, atomic_density, fourier_transform_method="fft", r_step=r_step
87127
)
88128
fr_optimized_fft = calculate_fr(sq_optimized_fft, method="fft")
89129

0 commit comments

Comments
 (0)