Skip to content

Commit

Permalink
Merge pull request #120 from yucongalicechen/xtype
Browse files Browse the repository at this point in the history
Implement xtype other than tth
  • Loading branch information
sbillinge authored Nov 1, 2024
2 parents 79a32e7 + b558f18 commit 651248d
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 25 deletions.
23 changes: 23 additions & 0 deletions news/xtype.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
**Added:**

* Support for independent variables other than two-theta.

**Changed:**

* <news item>

**Deprecated:**

* <news item>

**Removed:**

* <news item>

**Fixed:**

* <news item>

**Security:**

* <news item>
23 changes: 11 additions & 12 deletions src/diffpy/labpdfproc/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pandas as pd
from scipy.interpolate import interp1d

from diffpy.utils.scattering_objects.diffraction_objects import Diffraction_object
from diffpy.utils.scattering_objects.diffraction_objects import XQUANTITIES, Diffraction_object

RADIUS_MM = 1
N_POINTS_ON_DIAMETER = 300
Expand Down Expand Up @@ -198,7 +198,6 @@ def _cve_brute_force(diffraction_data, mud):
"tth",
metadata=diffraction_data.metadata,
name=f"absorption correction, cve, for {diffraction_data.name}",
wavelength=diffraction_data.wavelength,
scat_quantity="cve",
)
return cve_do
Expand Down Expand Up @@ -227,7 +226,6 @@ def _cve_polynomial_interpolation(diffraction_data, mud):
"tth",
metadata=diffraction_data.metadata,
name=f"absorption correction, cve, for {diffraction_data.name}",
wavelength=diffraction_data.wavelength,
scat_quantity="cve",
)
return cve_do
Expand All @@ -246,15 +244,18 @@ def _cve_method(method):
return methods[method]


def compute_cve(diffraction_data, mud, method="polynomial_interpolation"):
def compute_cve(diffraction_data, mud, method="polynomial_interpolation", xtype="tth"):
f"""
compute and interpolate the cve for the given diffraction data and mud using the selected method
Parameters
----------
diffraction_data Diffraction_object
the diffraction pattern
mud float
the mu*D of the diffraction object, where D is the diameter of the circle
xtype str
the quantity on the independent variable axis, allowed values are {*XQUANTITIES, }
method str
the method used to calculate cve, must be one of {* CVE_METHODS, }
Expand All @@ -264,22 +265,20 @@ def compute_cve(diffraction_data, mud, method="polynomial_interpolation"):
"""

cve_function = _cve_method(method)
abdo_on_global_tth = cve_function(diffraction_data, mud)
global_tth = abdo_on_global_tth.on_tth[0]
cve_on_global_tth = abdo_on_global_tth.on_tth[1]
orig_grid = diffraction_data.on_tth[0]
newcve = np.interp(orig_grid, global_tth, cve_on_global_tth)
cve_do_on_global_grid = cve_function(diffraction_data, mud)
orig_grid = diffraction_data.on_xtype(xtype)[0]
global_xtype = cve_do_on_global_grid.on_xtype(xtype)[0]
cve_on_global_xtype = cve_do_on_global_grid.on_xtype(xtype)[1]
newcve = np.interp(orig_grid, global_xtype, cve_on_global_xtype)
cve_do = Diffraction_object(wavelength=diffraction_data.wavelength)
cve_do.insert_scattering_quantity(
orig_grid,
newcve,
"tth",
xtype,
metadata=diffraction_data.metadata,
name=f"absorption correction, cve, for {diffraction_data.name}",
wavelength=diffraction_data.wavelength,
scat_quantity="cve",
)

return cve_do


Expand Down
11 changes: 5 additions & 6 deletions src/diffpy/labpdfproc/labpdfprocapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ def get_args(override_cli_inputs=None):
help=(
f"The quantity on the independent variable axis. Allowed "
f"values: {*XQUANTITIES, }. If not specified then two-theta "
f"is assumed for the independent variable. Only implemented for "
f"tth currently."
f"is assumed for the independent variable."
),
default="tth",
)
Expand Down Expand Up @@ -160,19 +159,19 @@ def main():
input_pattern.insert_scattering_quantity(
xarray,
yarray,
"tth",
args.xtype,
scat_quantity="x-ray",
name=filepath.stem,
metadata=load_metadata(args, filepath),
)

absorption_correction = compute_cve(input_pattern, args.mud, args.method)
absorption_correction = compute_cve(input_pattern, args.mud, args.method, args.xtype)
corrected_data = apply_corr(input_pattern, absorption_correction)
corrected_data.name = f"Absorption corrected input_data: {input_pattern.name}"
corrected_data.dump(f"{outfile}", xtype="tth")
corrected_data.dump(f"{outfile}", xtype=args.xtype)

if args.output_correction:
absorption_correction.dump(f"{corrfile}", xtype="tth")
absorption_correction.dump(f"{corrfile}", xtype=args.xtype)


if __name__ == "__main__":
Expand Down
21 changes: 21 additions & 0 deletions src/diffpy/labpdfproc/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pathlib import Path

from diffpy.labpdfproc.mud_calculator import compute_mud
from diffpy.utils.scattering_objects.diffraction_objects import QQUANTITIES, XQUANTITIES
from diffpy.utils.tools import get_package_info, get_user_info

WAVELENGTHS = {"Mo": 0.71, "Ag": 0.59, "Cu": 1.54}
Expand Down Expand Up @@ -138,6 +139,25 @@ def set_wavelength(args):
return args


def set_xtype(args):
f"""
Set the xtype based on the given input arguments, raise an error if xtype is not one of {*XQUANTITIES, }

Parameters
----------
args argparse.Namespace
the arguments from the parser

Returns
-------
args argparse.Namespace
"""
if args.xtype.lower() not in XQUANTITIES:
raise ValueError(f"Unknown xtype: {args.xtype}. Allowed xtypes are {*XQUANTITIES, }.")
args.xtype = "q" if args.xtype.lower() in QQUANTITIES else "tth"
return args


def set_mud(args):
"""
Set the mud based on the given input arguments
Expand Down Expand Up @@ -260,6 +280,7 @@ def preprocessing_args(args):
args = set_input_lists(args)
args.output_directory = set_output_directory(args)
args = set_wavelength(args)
args = set_xtype(args)
args = set_mud(args)
args = load_user_metadata(args)
return args
Expand Down
22 changes: 15 additions & 7 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,29 +58,37 @@ def test_set_muls_at_angle(inputs, expected):
assert actual_muls_sorted == pytest.approx(expected_muls_sorted, rel=1e-4, abs=1e-6)


def _instantiate_test_do(xarray, yarray, name="test", scat_quantity="x-ray"):
def _instantiate_test_do(xarray, yarray, xtype="tth", name="test", scat_quantity="x-ray"):
test_do = Diffraction_object(wavelength=1.54)
test_do.insert_scattering_quantity(
xarray,
yarray,
"tth",
xtype,
scat_quantity=scat_quantity,
name=name,
metadata={"thing1": 1, "thing2": "thing2"},
)
return test_do


def test_compute_cve(mocker):
params4 = [
(["tth"], [np.array([90, 90.1, 90.2]), np.array([0.5, 0.5, 0.5]), "tth"]),
(["q"], [np.array([5.76998, 5.77501, 5.78004]), np.array([0.5, 0.5, 0.5]), "q"]),
]


@pytest.mark.parametrize("inputs, expected", params4)
def test_compute_cve(inputs, expected, mocker):
xarray, yarray = np.array([90, 90.1, 90.2]), np.array([2, 2, 2])
expected_cve = np.array([0.5, 0.5, 0.5])
mocker.patch("diffpy.labpdfproc.functions.TTH_GRID", xarray)
mocker.patch("numpy.interp", return_value=expected_cve)
input_pattern = _instantiate_test_do(xarray, yarray)
actual_cve_do = compute_cve(input_pattern, mud=1)
actual_cve_do = compute_cve(input_pattern, mud=1, method="polynomial_interpolation", xtype=inputs[0])
expected_cve_do = _instantiate_test_do(
xarray,
expected_cve,
expected[0],
expected[1],
expected[2],
name="absorption correction, cve, for test",
scat_quantity="cve",
)
Expand All @@ -92,7 +100,7 @@ def test_compute_cve(mocker):
[7, "polynomial_interpolation"],
[
f"mu*D is out of the acceptable range (0.5 to 6) for polynomial interpolation. "
f"Please rerun with a value within this range or specifying another method from {* CVE_METHODS, }."
f"Please rerun with a value within this range or specifying another method from {*CVE_METHODS, }."
],
),
([1, "invalid_method"], [f"Unknown method: invalid_method. Allowed methods are {*CVE_METHODS, }."]),
Expand Down
27 changes: 27 additions & 0 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
set_mud,
set_output_directory,
set_wavelength,
set_xtype,
)
from diffpy.utils.scattering_objects.diffraction_objects import XQUANTITIES

# Use cases can be found here: https://github.com/diffpy/diffpy.labpdfproc/issues/48

Expand Down Expand Up @@ -189,6 +191,31 @@ def test_set_wavelength_bad(inputs, msg):
actual_args = set_wavelength(actual_args)


params4 = [
([], ["tth"]),
(["--xtype", "2theta"], ["tth"]),
(["--xtype", "d"], ["tth"]),
(["--xtype", "q"], ["q"]),
]


@pytest.mark.parametrize("inputs, expected", params4)
def test_set_xtype(inputs, expected):
cli_inputs = ["2.5", "data.xy"] + inputs
actual_args = get_args(cli_inputs)
actual_args = set_xtype(actual_args)
assert actual_args.xtype == expected[0]


def test_set_xtype_bad():
cli_inputs = ["2.5", "data.xy", "--xtype", "invalid"]
actual_args = get_args(cli_inputs)
with pytest.raises(
ValueError, match=re.escape(f"Unknown xtype: invalid. Allowed xtypes are {*XQUANTITIES, }.")
):
actual_args = set_xtype(actual_args)


def test_set_mud(user_filesystem):
cli_inputs = ["2.5", "data.xy"]
actual_args = get_args(cli_inputs)
Expand Down

0 comments on commit 651248d

Please sign in to comment.