Skip to content

Commit 5d651b3

Browse files
committed
Clean up relative methods into classes
The classes keep track of the current values of the relative parameters (before they were modified by lmfit). This is necessary for modifying all detectors by the diff of the change. Signed-off-by: Patrick Avery <[email protected]>
1 parent 090402d commit 5d651b3

File tree

8 files changed

+129
-83
lines changed

8 files changed

+129
-83
lines changed

hexrd/fitting/calibration/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from .relative_constraints import RelativeConstraintsType
2+
13
from .instrument import InstrumentCalibrator
24
from .laue import LaueCalibrator
3-
from .lmfit_param_handling import RelativeConstraints
45
from .multigrain import calibrate_instrument_from_sx, generate_parameter_names
56
from .powder import PowderCalibrator
67
from .structureless import StructurelessCalibrator
@@ -14,7 +15,7 @@
1415
'InstrumentCalibrator',
1516
'LaueCalibrator',
1617
'PowderCalibrator',
17-
'RelativeConstraints',
18+
'RelativeConstraintsType',
1819
'StructurelessCalibrator',
1920
'StructureLessCalibrator',
2021
]

hexrd/fitting/calibration/instrument.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from typing import Optional
23

34
import lmfit
45
import numpy as np
@@ -9,7 +10,11 @@
910
DEFAULT_EULER_CONVENTION,
1011
update_instrument_from_params,
1112
validate_params_list,
13+
)
14+
from .relative_constraints import (
15+
create_relative_constraints,
1216
RelativeConstraints,
17+
RelativeConstraintsType,
1318
)
1419

1520
logger = logging.getLogger()
@@ -24,7 +29,7 @@ class InstrumentCalibrator:
2429
def __init__(self, *args, engineering_constraints=None,
2530
set_refinements_from_instrument_flags=True,
2631
euler_convention=DEFAULT_EULER_CONVENTION,
27-
relative_constraints=RelativeConstraints.none):
32+
relative_constraints_type=RelativeConstraintsType.none):
2833
"""
2934
Model for instrument calibration class as a function of
3035
@@ -47,7 +52,8 @@ def __init__(self, *args, engineering_constraints=None,
4752
assert calib.instr is self.instr, \
4853
"all calibrators must refer to the same instrument"
4954
self._engineering_constraints = engineering_constraints
50-
self._relative_constraints = relative_constraints
55+
self._relative_constraints = create_relative_constraints(
56+
relative_constraints_type, self.instr)
5157
self.euler_convention = euler_convention
5258

5359
self.params = self.make_lmfit_params()
@@ -164,18 +170,32 @@ def engineering_constraints(self, v):
164170
self._engineering_constraints = v
165171
self.params = self.make_lmfit_params()
166172

173+
@property
174+
def relative_constraints_type(self):
175+
return self._relative_constraints.type
176+
177+
@relative_constraints_type.setter
178+
def relative_constraints_type(self, v: Optional[RelativeConstraintsType]):
179+
v = v if v is not None else RelativeConstraintsType.none
180+
181+
current = getattr(self, '_relative_constraints', None)
182+
if current is None or current.type != v:
183+
self.relative_constraints = create_relative_constraints(
184+
v, self.instr)
185+
167186
@property
168187
def relative_constraints(self) -> RelativeConstraints:
169188
return self._relative_constraints
170189

171190
@relative_constraints.setter
172191
def relative_constraints(self, v: RelativeConstraints):
173-
if v == self._relative_constraints:
174-
return
175-
176192
self._relative_constraints = v
177193
self.params = self.make_lmfit_params()
178194

195+
def reset_relative_constraint_params(self):
196+
# Set them back to zero.
197+
self.relative_constraints.reset()
198+
179199
def run_calibration(self, odict):
180200
resd0 = self.residual()
181201
nrm_ssr_0 = _normalized_ssqr(resd0)

hexrd/fitting/calibration/lmfit_param_handling.py

Lines changed: 71 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from enum import Enum
1+
from typing import Optional
22

33
import lmfit
44
import numpy as np
@@ -17,24 +17,18 @@
1717
rotMatOfExpMap,
1818
)
1919
from hexrd.material.unitcell import _lpname
20+
from .relative_constraints import (
21+
RelativeConstraints,
22+
RelativeConstraintsType,
23+
)
2024

2125

2226
# First is the axes_order, second is extrinsic
2327
DEFAULT_EULER_CONVENTION = ('zxz', False)
2428

2529

26-
class RelativeConstraints(Enum):
27-
"""These are relative constraints between the detectors"""
28-
# 'none' means no relative constraints
29-
none = 'None'
30-
# 'group' means constrain tilts/translations within a group
31-
group = 'Group'
32-
# 'system' means constrain tilts/translations within the whole system
33-
system = 'System'
34-
35-
3630
def create_instr_params(instr, euler_convention=DEFAULT_EULER_CONVENTION,
37-
relative_constraints=RelativeConstraints.none):
31+
relative_constraints=None):
3832
# add with tuples: (NAME VALUE VARY MIN MAX EXPR BRUTE_STEP)
3933
parms_list = []
4034

@@ -62,23 +56,27 @@ def create_instr_params(instr, euler_convention=DEFAULT_EULER_CONVENTION,
6256
parms_list.append(('instr_tvec_y', instr.tvec[1], False, -np.inf, np.inf))
6357
parms_list.append(('instr_tvec_z', instr.tvec[2], False, -np.inf, np.inf))
6458

65-
if relative_constraints == RelativeConstraints.none:
59+
if (
60+
relative_constraints is None or
61+
relative_constraints.type == RelativeConstraintsType.none
62+
):
6663
add_unconstrained_detector_parameters(
6764
instr,
6865
euler_convention,
6966
parms_list,
7067
)
71-
elif relative_constraints == RelativeConstraints.group:
68+
elif relative_constraints.type == RelativeConstraintsType.group:
7269
# This should be implemented soon
73-
raise NotImplementedError(relative_constraints)
74-
elif relative_constraints == RelativeConstraints.system:
70+
raise NotImplementedError(relative_constraints.type)
71+
elif relative_constraints.type == RelativeConstraintsType.system:
7572
add_system_constrained_detector_parameters(
7673
instr,
7774
euler_convention,
7875
parms_list,
76+
relative_constraints,
7977
)
8078
else:
81-
raise NotImplementedError(relative_constraints)
79+
raise NotImplementedError(relative_constraints.type)
8280

8381
return parms_list
8482

@@ -122,10 +120,24 @@ def add_unconstrained_detector_parameters(instr, euler_convention, parms_list):
122120
-np.inf, np.inf))
123121

124122

125-
def add_system_constrained_detector_parameters(instr, euler_convention,
126-
parms_list):
127-
mean_center = instr.mean_detector_center
128-
mean_tilt = instr.mean_detector_tilt
123+
def add_system_constrained_detector_parameters(
124+
instr, euler_convention,
125+
parms_list, relative_constraints: RelativeConstraints):
126+
system_params = relative_constraints.params
127+
system_tvec = system_params['translation']
128+
system_tilt = system_params['tilt']
129+
130+
if euler_convention is not None:
131+
# Convert the tilt to the specified Euler convention
132+
normalized = normalize_euler_convention(euler_convention)
133+
rme = RotMatEuler(
134+
np.zeros(3,),
135+
axes_order=normalized[0],
136+
extrinsic=normalized[1],
137+
)
138+
139+
rme.rmat = _tilt_to_rmat(system_tilt, None)
140+
system_tilt = np.degrees(rme.angles)
129141

130142
tvec_names = [
131143
'system_tvec_x',
@@ -138,12 +150,12 @@ def add_system_constrained_detector_parameters(instr, euler_convention,
138150
tilt_deltas = [2, 2, 2]
139151

140152
for i, name in enumerate(tvec_names):
141-
value = mean_center[i]
153+
value = system_tvec[i]
142154
delta = tvec_deltas[i]
143155
parms_list.append((name, value, True, value - delta, value + delta))
144156

145157
for i, name in enumerate(tilt_names):
146-
value = mean_tilt[i]
158+
value = system_tilt[i]
147159
delta = tilt_deltas[i]
148160
parms_list.append((name, value, True, value - delta, value + delta))
149161

@@ -160,8 +172,10 @@ def create_beam_param_names(instr: HEDMInstrument) -> dict[str, str]:
160172
return param_names
161173

162174

163-
def update_instrument_from_params(instr, params, euler_convention,
164-
relative_constraints):
175+
def update_instrument_from_params(
176+
instr, params,
177+
euler_convention=DEFAULT_EULER_CONVENTION,
178+
relative_constraints: Optional[RelativeConstraints] = None):
165179
"""
166180
this function updates the instrument from the
167181
lmfit parameter list. we don't have to keep track
@@ -196,23 +210,27 @@ def update_instrument_from_params(instr, params, euler_convention,
196210
params['instr_tvec_z'].value]
197211
instr.tvec = np.r_[instr_tvec]
198212

199-
if relative_constraints == RelativeConstraints.none:
213+
if (
214+
relative_constraints is None or
215+
relative_constraints.type == RelativeConstraintsType.none
216+
):
200217
update_unconstrained_detector_parameters(
201218
instr,
202219
params,
203220
euler_convention,
204221
)
205-
elif relative_constraints == RelativeConstraints.group:
222+
elif relative_constraints.type == RelativeConstraintsType.group:
206223
# This should be implemented soon
207-
raise NotImplementedError(relative_constraints)
208-
elif relative_constraints == RelativeConstraints.system:
224+
raise NotImplementedError(relative_constraints.type)
225+
elif relative_constraints.type == RelativeConstraintsType.system:
209226
update_system_constrained_detector_parameters(
210227
instr,
211228
params,
212229
euler_convention,
230+
relative_constraints,
213231
)
214232
else:
215-
raise NotImplementedError(relative_constraints)
233+
raise NotImplementedError(relative_constraints.type)
216234

217235

218236
def update_unconstrained_detector_parameters(instr, params, euler_convention):
@@ -245,10 +263,15 @@ def update_unconstrained_detector_parameters(instr, params, euler_convention):
245263
)
246264

247265

248-
def update_system_constrained_detector_parameters(instr, params, euler_convention):
249-
# We will always rotate/translate about the center of the group
266+
def update_system_constrained_detector_parameters(
267+
instr, params, euler_convention,
268+
relative_constraints: RelativeConstraints):
269+
# We will always rotate about the center of the detectors
250270
mean_center = instr.mean_detector_center
251-
mean_tilt = instr.mean_detector_tilt
271+
272+
system_params = relative_constraints.params
273+
system_tvec = system_params['translation']
274+
system_tilt = system_params['tilt']
252275

253276
tvec_names = [
254277
'system_tvec_x',
@@ -263,11 +286,11 @@ def update_system_constrained_detector_parameters(instr, params, euler_conventio
263286
if any(params[x].vary for x in tilt_names):
264287
# Find the change in tilt, create an rmat, then apply to detector tilts
265288
# and translations.
266-
new_mean_tilt = np.array([params[x].value for x in tilt_names])
289+
new_system_tilt = np.array([params[x].value for x in tilt_names])
267290

268-
# The old mean tilt was in the None convention
269-
old_rmat = _tilt_to_rmat(mean_tilt, None)
270-
new_rmat = _tilt_to_rmat(new_mean_tilt, euler_convention)
291+
# The old system tilt was in the None convention
292+
old_rmat = _tilt_to_rmat(system_tilt, None)
293+
new_rmat = _tilt_to_rmat(new_system_tilt, euler_convention)
271294

272295
# Compute the rmat used to convert from old to new
273296
rmat_diff = new_rmat @ old_rmat.T
@@ -276,19 +299,26 @@ def update_system_constrained_detector_parameters(instr, params, euler_conventio
276299
for panel in instr.detectors.values():
277300
panel.tilt = _rmat_to_tilt(rmat_diff @ panel.rmat)
278301

279-
# Also rotate the detectors about the center
302+
# Also rotate the detectors about the mean center
280303
panel.tvec = rmat_diff @ (panel.tvec - mean_center) + mean_center
281304

305+
# Update the system tilt
306+
system_tilt[:] = _rmat_to_tilt(new_rmat)
307+
282308
if any(params[x].vary for x in tvec_names):
283309
# Find the change in center and shift all tvecs
284-
new_mean_center = np.array([params[x].value for x in tvec_names])
310+
new_system_tvec = np.array([params[x].value for x in tvec_names])
285311

286-
diff = new_mean_center - mean_center
312+
diff = new_system_tvec - system_tvec
287313
for panel in instr.detectors.values():
288314
panel.tvec += diff
289315

316+
# Update the system tvec
317+
system_tvec[:] = new_system_tvec
318+
290319

291-
def _tilt_to_rmat(tilt: np.ndarray, euler_convention: dict | tuple) -> np.ndarray:
320+
def _tilt_to_rmat(tilt: np.ndarray,
321+
euler_convention: dict | tuple) -> np.ndarray:
292322
# Convert the tilt to exponential map parameters, and then
293323
# to the rotation matrix, and return.
294324
if euler_convention is None:

hexrd/fitting/calibration/structureless.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import copy
2+
from typing import Optional
3+
24
import lmfit
35
import numpy as np
46

@@ -9,10 +11,14 @@
911
create_instr_params,
1012
create_tth_parameters,
1113
DEFAULT_EULER_CONVENTION,
12-
RelativeConstraints,
1314
tth_parameter_prefixes,
1415
update_instrument_from_params,
1516
)
17+
from .relative_constraints import (
18+
create_relative_constraints,
19+
RelativeConstraints,
20+
RelativeConstraintsType,
21+
)
1622

1723

1824
class StructurelessCalibrator:
@@ -39,14 +45,15 @@ def __init__(self,
3945
data,
4046
tth_distortion=None,
4147
engineering_constraints=None,
42-
relative_constraints=RelativeConstraints.none,
48+
relative_constraints_type=RelativeConstraintsType.none,
4349
euler_convention=DEFAULT_EULER_CONVENTION):
4450

4551
self._instr = instr
4652
self._data = data
4753
self._tth_distortion = tth_distortion
4854
self._engineering_constraints = engineering_constraints
49-
self._relative_constraints = relative_constraints
55+
self._relative_constraints = create_relative_constraints(
56+
relative_constraints_type, self.instr)
5057
self.euler_convention = euler_convention
5158
self._update_tth_distortion_panels()
5259
self.make_lmfit_params()
@@ -163,16 +170,26 @@ def _update_tth_distortion_panels(self):
163170
obj.panel = self.instr.detectors[det_key]
164171

165172
@property
166-
def relative_constraints(self):
173+
def relative_constraints_type(self):
174+
return self._relative_constraints.type
175+
176+
@relative_constraints_type.setter
177+
def relative_constraints_type(self, v: Optional[RelativeConstraintsType]):
178+
v = v if v is not None else RelativeConstraintsType.none
179+
180+
current = getattr(self, '_relative_constraints', None)
181+
if current is None or current.type != v:
182+
self.relative_constraints = create_relative_constraints(
183+
v, self.instr)
184+
185+
@property
186+
def relative_constraints(self) -> RelativeConstraints:
167187
return self._relative_constraints
168188

169189
@relative_constraints.setter
170-
def relative_constraints(self, v):
171-
if v == self._relative_constraints:
172-
return
173-
190+
def relative_constraints(self, v: RelativeConstraints):
174191
self._relative_constraints = v
175-
self.make_lmfit_params()
192+
self.params = self.make_lmfit_params()
176193

177194
@property
178195
def engineering_constraints(self):

0 commit comments

Comments
 (0)