Skip to content

Commit ad7f741

Browse files
committed
Clean up relative methods into classes
The classes keep track of the current values of the relative parameters (before they were modified by lmfit). This is necessary for modifying all detectors by the diff of the change. Signed-off-by: Patrick Avery <[email protected]>
1 parent 090402d commit ad7f741

File tree

9 files changed

+245
-99
lines changed

9 files changed

+245
-99
lines changed

hexrd/fitting/calibration/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from .instrument import InstrumentCalibrator
22
from .laue import LaueCalibrator
3-
from .lmfit_param_handling import RelativeConstraints
43
from .multigrain import calibrate_instrument_from_sx, generate_parameter_names
54
from .powder import PowderCalibrator
65
from .structureless import StructurelessCalibrator
@@ -14,7 +13,6 @@
1413
'InstrumentCalibrator',
1514
'LaueCalibrator',
1615
'PowderCalibrator',
17-
'RelativeConstraints',
1816
'StructurelessCalibrator',
1917
'StructureLessCalibrator',
2018
]

hexrd/fitting/calibration/instrument.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from typing import Optional
23

34
import lmfit
45
import numpy as np
@@ -9,7 +10,11 @@
910
DEFAULT_EULER_CONVENTION,
1011
update_instrument_from_params,
1112
validate_params_list,
13+
)
14+
from .relative_constraints import (
15+
create_relative_constraints,
1216
RelativeConstraints,
17+
RelativeConstraintsType,
1318
)
1419

1520
logger = logging.getLogger()
@@ -24,7 +29,7 @@ class InstrumentCalibrator:
2429
def __init__(self, *args, engineering_constraints=None,
2530
set_refinements_from_instrument_flags=True,
2631
euler_convention=DEFAULT_EULER_CONVENTION,
27-
relative_constraints=RelativeConstraints.none):
32+
relative_constraints_type=RelativeConstraintsType.none):
2833
"""
2934
Model for instrument calibration class as a function of
3035
@@ -47,7 +52,8 @@ def __init__(self, *args, engineering_constraints=None,
4752
assert calib.instr is self.instr, \
4853
"all calibrators must refer to the same instrument"
4954
self._engineering_constraints = engineering_constraints
50-
self._relative_constraints = relative_constraints
55+
self._relative_constraints = create_relative_constraints(
56+
relative_constraints_type, self.instr)
5157
self.euler_convention = euler_convention
5258

5359
self.params = self.make_lmfit_params()
@@ -164,18 +170,32 @@ def engineering_constraints(self, v):
164170
self._engineering_constraints = v
165171
self.params = self.make_lmfit_params()
166172

173+
@property
174+
def relative_constraints_type(self):
175+
return self._relative_constraints.type
176+
177+
@relative_constraints_type.setter
178+
def relative_constraints_type(self, v: Optional[RelativeConstraintsType]):
179+
v = v if v is not None else RelativeConstraintsType.none
180+
181+
current = getattr(self, '_relative_constraints', None)
182+
if current is None or current.type != v:
183+
self.relative_constraints = create_relative_constraints(
184+
v, self.instr)
185+
167186
@property
168187
def relative_constraints(self) -> RelativeConstraints:
169188
return self._relative_constraints
170189

171190
@relative_constraints.setter
172191
def relative_constraints(self, v: RelativeConstraints):
173-
if v == self._relative_constraints:
174-
return
175-
176192
self._relative_constraints = v
177193
self.params = self.make_lmfit_params()
178194

195+
def reset_relative_constraint_params(self):
196+
# Set them back to zero.
197+
self.relative_constraints.reset()
198+
179199
def run_calibration(self, odict):
180200
resd0 = self.residual()
181201
nrm_ssr_0 = _normalized_ssqr(resd0)

hexrd/fitting/calibration/lmfit_param_handling.py

Lines changed: 71 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from enum import Enum
1+
from typing import Optional
22

33
import lmfit
44
import numpy as np
@@ -17,24 +17,18 @@
1717
rotMatOfExpMap,
1818
)
1919
from hexrd.material.unitcell import _lpname
20+
from .relative_constraints import (
21+
RelativeConstraints,
22+
RelativeConstraintsType,
23+
)
2024

2125

2226
# First is the axes_order, second is extrinsic
2327
DEFAULT_EULER_CONVENTION = ('zxz', False)
2428

2529

26-
class RelativeConstraints(Enum):
27-
"""These are relative constraints between the detectors"""
28-
# 'none' means no relative constraints
29-
none = 'None'
30-
# 'group' means constrain tilts/translations within a group
31-
group = 'Group'
32-
# 'system' means constrain tilts/translations within the whole system
33-
system = 'System'
34-
35-
3630
def create_instr_params(instr, euler_convention=DEFAULT_EULER_CONVENTION,
37-
relative_constraints=RelativeConstraints.none):
31+
relative_constraints=None):
3832
# add with tuples: (NAME VALUE VARY MIN MAX EXPR BRUTE_STEP)
3933
parms_list = []
4034

@@ -62,23 +56,27 @@ def create_instr_params(instr, euler_convention=DEFAULT_EULER_CONVENTION,
6256
parms_list.append(('instr_tvec_y', instr.tvec[1], False, -np.inf, np.inf))
6357
parms_list.append(('instr_tvec_z', instr.tvec[2], False, -np.inf, np.inf))
6458

65-
if relative_constraints == RelativeConstraints.none:
59+
if (
60+
relative_constraints is None or
61+
relative_constraints.type == RelativeConstraintsType.none
62+
):
6663
add_unconstrained_detector_parameters(
6764
instr,
6865
euler_convention,
6966
parms_list,
7067
)
71-
elif relative_constraints == RelativeConstraints.group:
68+
elif relative_constraints.type == RelativeConstraintsType.group:
7269
# This should be implemented soon
73-
raise NotImplementedError(relative_constraints)
74-
elif relative_constraints == RelativeConstraints.system:
70+
raise NotImplementedError(relative_constraints.type)
71+
elif relative_constraints.type == RelativeConstraintsType.system:
7572
add_system_constrained_detector_parameters(
7673
instr,
7774
euler_convention,
7875
parms_list,
76+
relative_constraints,
7977
)
8078
else:
81-
raise NotImplementedError(relative_constraints)
79+
raise NotImplementedError(relative_constraints.type)
8280

8381
return parms_list
8482

@@ -122,10 +120,24 @@ def add_unconstrained_detector_parameters(instr, euler_convention, parms_list):
122120
-np.inf, np.inf))
123121

124122

125-
def add_system_constrained_detector_parameters(instr, euler_convention,
126-
parms_list):
127-
mean_center = instr.mean_detector_center
128-
mean_tilt = instr.mean_detector_tilt
123+
def add_system_constrained_detector_parameters(
124+
instr, euler_convention,
125+
parms_list, relative_constraints: RelativeConstraints):
126+
system_params = relative_constraints.params
127+
system_tvec = system_params['translation']
128+
system_tilt = system_params['tilt']
129+
130+
if euler_convention is not None:
131+
# Convert the tilt to the specified Euler convention
132+
normalized = normalize_euler_convention(euler_convention)
133+
rme = RotMatEuler(
134+
np.zeros(3,),
135+
axes_order=normalized[0],
136+
extrinsic=normalized[1],
137+
)
138+
139+
rme.rmat = _tilt_to_rmat(system_tilt, None)
140+
system_tilt = np.degrees(rme.angles)
129141

130142
tvec_names = [
131143
'system_tvec_x',
@@ -138,12 +150,12 @@ def add_system_constrained_detector_parameters(instr, euler_convention,
138150
tilt_deltas = [2, 2, 2]
139151

140152
for i, name in enumerate(tvec_names):
141-
value = mean_center[i]
153+
value = system_tvec[i]
142154
delta = tvec_deltas[i]
143155
parms_list.append((name, value, True, value - delta, value + delta))
144156

145157
for i, name in enumerate(tilt_names):
146-
value = mean_tilt[i]
158+
value = system_tilt[i]
147159
delta = tilt_deltas[i]
148160
parms_list.append((name, value, True, value - delta, value + delta))
149161

@@ -160,8 +172,10 @@ def create_beam_param_names(instr: HEDMInstrument) -> dict[str, str]:
160172
return param_names
161173

162174

163-
def update_instrument_from_params(instr, params, euler_convention,
164-
relative_constraints):
175+
def update_instrument_from_params(
176+
instr, params,
177+
euler_convention=DEFAULT_EULER_CONVENTION,
178+
relative_constraints: Optional[RelativeConstraints] = None):
165179
"""
166180
this function updates the instrument from the
167181
lmfit parameter list. we don't have to keep track
@@ -196,23 +210,27 @@ def update_instrument_from_params(instr, params, euler_convention,
196210
params['instr_tvec_z'].value]
197211
instr.tvec = np.r_[instr_tvec]
198212

199-
if relative_constraints == RelativeConstraints.none:
213+
if (
214+
relative_constraints is None or
215+
relative_constraints.type == RelativeConstraintsType.none
216+
):
200217
update_unconstrained_detector_parameters(
201218
instr,
202219
params,
203220
euler_convention,
204221
)
205-
elif relative_constraints == RelativeConstraints.group:
222+
elif relative_constraints.type == RelativeConstraintsType.group:
206223
# This should be implemented soon
207-
raise NotImplementedError(relative_constraints)
208-
elif relative_constraints == RelativeConstraints.system:
224+
raise NotImplementedError(relative_constraints.type)
225+
elif relative_constraints.type == RelativeConstraintsType.system:
209226
update_system_constrained_detector_parameters(
210227
instr,
211228
params,
212229
euler_convention,
230+
relative_constraints,
213231
)
214232
else:
215-
raise NotImplementedError(relative_constraints)
233+
raise NotImplementedError(relative_constraints.type)
216234

217235

218236
def update_unconstrained_detector_parameters(instr, params, euler_convention):
@@ -245,10 +263,15 @@ def update_unconstrained_detector_parameters(instr, params, euler_convention):
245263
)
246264

247265

248-
def update_system_constrained_detector_parameters(instr, params, euler_convention):
249-
# We will always rotate/translate about the center of the group
266+
def update_system_constrained_detector_parameters(
267+
instr, params, euler_convention,
268+
relative_constraints: RelativeConstraints):
269+
# We will always rotate about the center of the detectors
250270
mean_center = instr.mean_detector_center
251-
mean_tilt = instr.mean_detector_tilt
271+
272+
system_params = relative_constraints.params
273+
system_tvec = system_params['translation']
274+
system_tilt = system_params['tilt']
252275

253276
tvec_names = [
254277
'system_tvec_x',
@@ -263,11 +286,11 @@ def update_system_constrained_detector_parameters(instr, params, euler_conventio
263286
if any(params[x].vary for x in tilt_names):
264287
# Find the change in tilt, create an rmat, then apply to detector tilts
265288
# and translations.
266-
new_mean_tilt = np.array([params[x].value for x in tilt_names])
289+
new_system_tilt = np.array([params[x].value for x in tilt_names])
267290

268-
# The old mean tilt was in the None convention
269-
old_rmat = _tilt_to_rmat(mean_tilt, None)
270-
new_rmat = _tilt_to_rmat(new_mean_tilt, euler_convention)
291+
# The old system tilt was in the None convention
292+
old_rmat = _tilt_to_rmat(system_tilt, None)
293+
new_rmat = _tilt_to_rmat(new_system_tilt, euler_convention)
271294

272295
# Compute the rmat used to convert from old to new
273296
rmat_diff = new_rmat @ old_rmat.T
@@ -276,19 +299,26 @@ def update_system_constrained_detector_parameters(instr, params, euler_conventio
276299
for panel in instr.detectors.values():
277300
panel.tilt = _rmat_to_tilt(rmat_diff @ panel.rmat)
278301

279-
# Also rotate the detectors about the center
302+
# Also rotate the detectors about the mean center
280303
panel.tvec = rmat_diff @ (panel.tvec - mean_center) + mean_center
281304

305+
# Update the system tilt
306+
system_tilt[:] = _rmat_to_tilt(new_rmat)
307+
282308
if any(params[x].vary for x in tvec_names):
283309
# Find the change in center and shift all tvecs
284-
new_mean_center = np.array([params[x].value for x in tvec_names])
310+
new_system_tvec = np.array([params[x].value for x in tvec_names])
285311

286-
diff = new_mean_center - mean_center
312+
diff = new_system_tvec - system_tvec
287313
for panel in instr.detectors.values():
288314
panel.tvec += diff
289315

316+
# Update the system tvec
317+
system_tvec[:] = new_system_tvec
318+
290319

291-
def _tilt_to_rmat(tilt: np.ndarray, euler_convention: dict | tuple) -> np.ndarray:
320+
def _tilt_to_rmat(tilt: np.ndarray,
321+
euler_convention: dict | tuple) -> np.ndarray:
292322
# Convert the tilt to exponential map parameters, and then
293323
# to the rotation matrix, and return.
294324
if euler_convention is None:

0 commit comments

Comments
 (0)