diff --git a/news/xtype.rst b/news/xtype.rst new file mode 100644 index 0000000..20ecef3 --- /dev/null +++ b/news/xtype.rst @@ -0,0 +1,23 @@ +**Added:** + +* Support for independent variables other than two-theta. + +**Changed:** + +* + +**Deprecated:** + +* + +**Removed:** + +* + +**Fixed:** + +* + +**Security:** + +* diff --git a/src/diffpy/labpdfproc/functions.py b/src/diffpy/labpdfproc/functions.py index d423468..35d587c 100644 --- a/src/diffpy/labpdfproc/functions.py +++ b/src/diffpy/labpdfproc/functions.py @@ -5,7 +5,7 @@ import pandas as pd from scipy.interpolate import interp1d -from diffpy.utils.scattering_objects.diffraction_objects import Diffraction_object +from diffpy.utils.scattering_objects.diffraction_objects import XQUANTITIES, Diffraction_object RADIUS_MM = 1 N_POINTS_ON_DIAMETER = 300 @@ -198,7 +198,6 @@ def _cve_brute_force(diffraction_data, mud): "tth", metadata=diffraction_data.metadata, name=f"absorption correction, cve, for {diffraction_data.name}", - wavelength=diffraction_data.wavelength, scat_quantity="cve", ) return cve_do @@ -227,7 +226,6 @@ def _cve_polynomial_interpolation(diffraction_data, mud): "tth", metadata=diffraction_data.metadata, name=f"absorption correction, cve, for {diffraction_data.name}", - wavelength=diffraction_data.wavelength, scat_quantity="cve", ) return cve_do @@ -246,15 +244,18 @@ def _cve_method(method): return methods[method] -def compute_cve(diffraction_data, mud, method="polynomial_interpolation"): +def compute_cve(diffraction_data, mud, method="polynomial_interpolation", xtype="tth"): f""" compute and interpolate the cve for the given diffraction data and mud using the selected method + Parameters ---------- diffraction_data Diffraction_object the diffraction pattern mud float the mu*D of the diffraction object, where D is the diameter of the circle + xtype str + the quantity on the independent variable axis, allowed values are {*XQUANTITIES, } method str the method used to calculate cve, must be one of {* CVE_METHODS, } @@ -264,22 +265,20 @@ def compute_cve(diffraction_data, mud, method="polynomial_interpolation"): """ cve_function = _cve_method(method) - abdo_on_global_tth = cve_function(diffraction_data, mud) - global_tth = abdo_on_global_tth.on_tth[0] - cve_on_global_tth = abdo_on_global_tth.on_tth[1] - orig_grid = diffraction_data.on_tth[0] - newcve = np.interp(orig_grid, global_tth, cve_on_global_tth) + cve_do_on_global_grid = cve_function(diffraction_data, mud) + orig_grid = diffraction_data.on_xtype(xtype)[0] + global_xtype = cve_do_on_global_grid.on_xtype(xtype)[0] + cve_on_global_xtype = cve_do_on_global_grid.on_xtype(xtype)[1] + newcve = np.interp(orig_grid, global_xtype, cve_on_global_xtype) cve_do = Diffraction_object(wavelength=diffraction_data.wavelength) cve_do.insert_scattering_quantity( orig_grid, newcve, - "tth", + xtype, metadata=diffraction_data.metadata, name=f"absorption correction, cve, for {diffraction_data.name}", - wavelength=diffraction_data.wavelength, scat_quantity="cve", ) - return cve_do diff --git a/src/diffpy/labpdfproc/labpdfprocapp.py b/src/diffpy/labpdfproc/labpdfprocapp.py index 1cbb78b..d63af8f 100644 --- a/src/diffpy/labpdfproc/labpdfprocapp.py +++ b/src/diffpy/labpdfproc/labpdfprocapp.py @@ -63,8 +63,7 @@ def get_args(override_cli_inputs=None): help=( f"The quantity on the independent variable axis. Allowed " f"values: {*XQUANTITIES, }. If not specified then two-theta " - f"is assumed for the independent variable. Only implemented for " - f"tth currently." + f"is assumed for the independent variable." ), default="tth", ) @@ -160,19 +159,19 @@ def main(): input_pattern.insert_scattering_quantity( xarray, yarray, - "tth", + args.xtype, scat_quantity="x-ray", name=filepath.stem, metadata=load_metadata(args, filepath), ) - absorption_correction = compute_cve(input_pattern, args.mud, args.method) + absorption_correction = compute_cve(input_pattern, args.mud, args.method, args.xtype) corrected_data = apply_corr(input_pattern, absorption_correction) corrected_data.name = f"Absorption corrected input_data: {input_pattern.name}" - corrected_data.dump(f"{outfile}", xtype="tth") + corrected_data.dump(f"{outfile}", xtype=args.xtype) if args.output_correction: - absorption_correction.dump(f"{corrfile}", xtype="tth") + absorption_correction.dump(f"{corrfile}", xtype=args.xtype) if __name__ == "__main__": diff --git a/src/diffpy/labpdfproc/tools.py b/src/diffpy/labpdfproc/tools.py index f5f31da..02b5451 100644 --- a/src/diffpy/labpdfproc/tools.py +++ b/src/diffpy/labpdfproc/tools.py @@ -2,6 +2,7 @@ from pathlib import Path from diffpy.labpdfproc.mud_calculator import compute_mud +from diffpy.utils.scattering_objects.diffraction_objects import QQUANTITIES, XQUANTITIES from diffpy.utils.tools import get_package_info, get_user_info WAVELENGTHS = {"Mo": 0.71, "Ag": 0.59, "Cu": 1.54} @@ -138,6 +139,25 @@ def set_wavelength(args): return args +def set_xtype(args): + f""" + Set the xtype based on the given input arguments, raise an error if xtype is not one of {*XQUANTITIES, } + + Parameters + ---------- + args argparse.Namespace + the arguments from the parser + + Returns + ------- + args argparse.Namespace + """ + if args.xtype.lower() not in XQUANTITIES: + raise ValueError(f"Unknown xtype: {args.xtype}. Allowed xtypes are {*XQUANTITIES, }.") + args.xtype = "q" if args.xtype.lower() in QQUANTITIES else "tth" + return args + + def set_mud(args): """ Set the mud based on the given input arguments @@ -260,6 +280,7 @@ def preprocessing_args(args): args = set_input_lists(args) args.output_directory = set_output_directory(args) args = set_wavelength(args) + args = set_xtype(args) args = set_mud(args) args = load_user_metadata(args) return args diff --git a/tests/test_functions.py b/tests/test_functions.py index ac086a5..f23c32d 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -58,12 +58,12 @@ def test_set_muls_at_angle(inputs, expected): assert actual_muls_sorted == pytest.approx(expected_muls_sorted, rel=1e-4, abs=1e-6) -def _instantiate_test_do(xarray, yarray, name="test", scat_quantity="x-ray"): +def _instantiate_test_do(xarray, yarray, xtype="tth", name="test", scat_quantity="x-ray"): test_do = Diffraction_object(wavelength=1.54) test_do.insert_scattering_quantity( xarray, yarray, - "tth", + xtype, scat_quantity=scat_quantity, name=name, metadata={"thing1": 1, "thing2": "thing2"}, @@ -71,16 +71,24 @@ def _instantiate_test_do(xarray, yarray, name="test", scat_quantity="x-ray"): return test_do -def test_compute_cve(mocker): +params4 = [ + (["tth"], [np.array([90, 90.1, 90.2]), np.array([0.5, 0.5, 0.5]), "tth"]), + (["q"], [np.array([5.76998, 5.77501, 5.78004]), np.array([0.5, 0.5, 0.5]), "q"]), +] + + +@pytest.mark.parametrize("inputs, expected", params4) +def test_compute_cve(inputs, expected, mocker): xarray, yarray = np.array([90, 90.1, 90.2]), np.array([2, 2, 2]) expected_cve = np.array([0.5, 0.5, 0.5]) mocker.patch("diffpy.labpdfproc.functions.TTH_GRID", xarray) mocker.patch("numpy.interp", return_value=expected_cve) input_pattern = _instantiate_test_do(xarray, yarray) - actual_cve_do = compute_cve(input_pattern, mud=1) + actual_cve_do = compute_cve(input_pattern, mud=1, method="polynomial_interpolation", xtype=inputs[0]) expected_cve_do = _instantiate_test_do( - xarray, - expected_cve, + expected[0], + expected[1], + expected[2], name="absorption correction, cve, for test", scat_quantity="cve", ) @@ -92,7 +100,7 @@ def test_compute_cve(mocker): [7, "polynomial_interpolation"], [ f"mu*D is out of the acceptable range (0.5 to 6) for polynomial interpolation. " - f"Please rerun with a value within this range or specifying another method from {* CVE_METHODS, }." + f"Please rerun with a value within this range or specifying another method from {*CVE_METHODS, }." ], ), ([1, "invalid_method"], [f"Unknown method: invalid_method. Allowed methods are {*CVE_METHODS, }."]), diff --git a/tests/test_tools.py b/tests/test_tools.py index afb7bee..f7c6425 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -16,7 +16,9 @@ set_mud, set_output_directory, set_wavelength, + set_xtype, ) +from diffpy.utils.scattering_objects.diffraction_objects import XQUANTITIES # Use cases can be found here: https://github.com/diffpy/diffpy.labpdfproc/issues/48 @@ -189,6 +191,31 @@ def test_set_wavelength_bad(inputs, msg): actual_args = set_wavelength(actual_args) +params4 = [ + ([], ["tth"]), + (["--xtype", "2theta"], ["tth"]), + (["--xtype", "d"], ["tth"]), + (["--xtype", "q"], ["q"]), +] + + +@pytest.mark.parametrize("inputs, expected", params4) +def test_set_xtype(inputs, expected): + cli_inputs = ["2.5", "data.xy"] + inputs + actual_args = get_args(cli_inputs) + actual_args = set_xtype(actual_args) + assert actual_args.xtype == expected[0] + + +def test_set_xtype_bad(): + cli_inputs = ["2.5", "data.xy", "--xtype", "invalid"] + actual_args = get_args(cli_inputs) + with pytest.raises( + ValueError, match=re.escape(f"Unknown xtype: invalid. Allowed xtypes are {*XQUANTITIES, }.") + ): + actual_args = set_xtype(actual_args) + + def test_set_mud(user_filesystem): cli_inputs = ["2.5", "data.xy"] actual_args = get_args(cli_inputs)