1- from enum import Enum
1+ from typing import Optional
22
33import lmfit
44import numpy as np
1717 rotMatOfExpMap ,
1818)
1919from hexrd .material .unitcell import _lpname
20+ from .relative_constraints import (
21+ RelativeConstraints ,
22+ RelativeConstraintsType ,
23+ )
2024
2125
2226# First is the axes_order, second is extrinsic
2327DEFAULT_EULER_CONVENTION = ('zxz' , False )
2428
2529
26- class RelativeConstraints (Enum ):
27- """These are relative constraints between the detectors"""
28- # 'none' means no relative constraints
29- none = 'None'
30- # 'group' means constrain tilts/translations within a group
31- group = 'Group'
32- # 'system' means constrain tilts/translations within the whole system
33- system = 'System'
34-
35-
3630def create_instr_params (instr , euler_convention = DEFAULT_EULER_CONVENTION ,
37- relative_constraints = RelativeConstraints . none ):
31+ relative_constraints = None ):
3832 # add with tuples: (NAME VALUE VARY MIN MAX EXPR BRUTE_STEP)
3933 parms_list = []
4034
@@ -62,23 +56,27 @@ def create_instr_params(instr, euler_convention=DEFAULT_EULER_CONVENTION,
6256 parms_list .append (('instr_tvec_y' , instr .tvec [1 ], False , - np .inf , np .inf ))
6357 parms_list .append (('instr_tvec_z' , instr .tvec [2 ], False , - np .inf , np .inf ))
6458
65- if relative_constraints == RelativeConstraints .none :
59+ if (
60+ relative_constraints is None or
61+ relative_constraints .type == RelativeConstraintsType .none
62+ ):
6663 add_unconstrained_detector_parameters (
6764 instr ,
6865 euler_convention ,
6966 parms_list ,
7067 )
71- elif relative_constraints == RelativeConstraints .group :
68+ elif relative_constraints . type == RelativeConstraintsType .group :
7269 # This should be implemented soon
73- raise NotImplementedError (relative_constraints )
74- elif relative_constraints == RelativeConstraints .system :
70+ raise NotImplementedError (relative_constraints . type )
71+ elif relative_constraints . type == RelativeConstraintsType .system :
7572 add_system_constrained_detector_parameters (
7673 instr ,
7774 euler_convention ,
7875 parms_list ,
76+ relative_constraints ,
7977 )
8078 else :
81- raise NotImplementedError (relative_constraints )
79+ raise NotImplementedError (relative_constraints . type )
8280
8381 return parms_list
8482
@@ -122,10 +120,24 @@ def add_unconstrained_detector_parameters(instr, euler_convention, parms_list):
122120 - np .inf , np .inf ))
123121
124122
125- def add_system_constrained_detector_parameters (instr , euler_convention ,
126- parms_list ):
127- mean_center = instr .mean_detector_center
128- mean_tilt = instr .mean_detector_tilt
123+ def add_system_constrained_detector_parameters (
124+ instr , euler_convention ,
125+ parms_list , relative_constraints : RelativeConstraints ):
126+ system_params = relative_constraints .params
127+ system_tvec = system_params ['translation' ]
128+ system_tilt = system_params ['tilt' ]
129+
130+ if euler_convention is not None :
131+ # Convert the tilt to the specified Euler convention
132+ normalized = normalize_euler_convention (euler_convention )
133+ rme = RotMatEuler (
134+ np .zeros (3 ,),
135+ axes_order = normalized [0 ],
136+ extrinsic = normalized [1 ],
137+ )
138+
139+ rme .rmat = _tilt_to_rmat (system_tilt , None )
140+ system_tilt = np .degrees (rme .angles )
129141
130142 tvec_names = [
131143 'system_tvec_x' ,
@@ -138,12 +150,12 @@ def add_system_constrained_detector_parameters(instr, euler_convention,
138150 tilt_deltas = [2 , 2 , 2 ]
139151
140152 for i , name in enumerate (tvec_names ):
141- value = mean_center [i ]
153+ value = system_tvec [i ]
142154 delta = tvec_deltas [i ]
143155 parms_list .append ((name , value , True , value - delta , value + delta ))
144156
145157 for i , name in enumerate (tilt_names ):
146- value = mean_tilt [i ]
158+ value = system_tilt [i ]
147159 delta = tilt_deltas [i ]
148160 parms_list .append ((name , value , True , value - delta , value + delta ))
149161
@@ -160,8 +172,10 @@ def create_beam_param_names(instr: HEDMInstrument) -> dict[str, str]:
160172 return param_names
161173
162174
163- def update_instrument_from_params (instr , params , euler_convention ,
164- relative_constraints ):
175+ def update_instrument_from_params (
176+ instr , params ,
177+ euler_convention = DEFAULT_EULER_CONVENTION ,
178+ relative_constraints : Optional [RelativeConstraints ] = None ):
165179 """
166180 this function updates the instrument from the
167181 lmfit parameter list. we don't have to keep track
@@ -196,23 +210,27 @@ def update_instrument_from_params(instr, params, euler_convention,
196210 params ['instr_tvec_z' ].value ]
197211 instr .tvec = np .r_ [instr_tvec ]
198212
199- if relative_constraints == RelativeConstraints .none :
213+ if (
214+ relative_constraints is None or
215+ relative_constraints .type == RelativeConstraintsType .none
216+ ):
200217 update_unconstrained_detector_parameters (
201218 instr ,
202219 params ,
203220 euler_convention ,
204221 )
205- elif relative_constraints == RelativeConstraints .group :
222+ elif relative_constraints . type == RelativeConstraintsType .group :
206223 # This should be implemented soon
207- raise NotImplementedError (relative_constraints )
208- elif relative_constraints == RelativeConstraints .system :
224+ raise NotImplementedError (relative_constraints . type )
225+ elif relative_constraints . type == RelativeConstraintsType .system :
209226 update_system_constrained_detector_parameters (
210227 instr ,
211228 params ,
212229 euler_convention ,
230+ relative_constraints ,
213231 )
214232 else :
215- raise NotImplementedError (relative_constraints )
233+ raise NotImplementedError (relative_constraints . type )
216234
217235
218236def update_unconstrained_detector_parameters (instr , params , euler_convention ):
@@ -245,10 +263,15 @@ def update_unconstrained_detector_parameters(instr, params, euler_convention):
245263 )
246264
247265
248- def update_system_constrained_detector_parameters (instr , params , euler_convention ):
249- # We will always rotate/translate about the center of the group
266+ def update_system_constrained_detector_parameters (
267+ instr , params , euler_convention ,
268+ relative_constraints : RelativeConstraints ):
269+ # We will always rotate about the center of the detectors
250270 mean_center = instr .mean_detector_center
251- mean_tilt = instr .mean_detector_tilt
271+
272+ system_params = relative_constraints .params
273+ system_tvec = system_params ['translation' ]
274+ system_tilt = system_params ['tilt' ]
252275
253276 tvec_names = [
254277 'system_tvec_x' ,
@@ -263,11 +286,11 @@ def update_system_constrained_detector_parameters(instr, params, euler_conventio
263286 if any (params [x ].vary for x in tilt_names ):
264287 # Find the change in tilt, create an rmat, then apply to detector tilts
265288 # and translations.
266- new_mean_tilt = np .array ([params [x ].value for x in tilt_names ])
289+ new_system_tilt = np .array ([params [x ].value for x in tilt_names ])
267290
268- # The old mean tilt was in the None convention
269- old_rmat = _tilt_to_rmat (mean_tilt , None )
270- new_rmat = _tilt_to_rmat (new_mean_tilt , euler_convention )
291+ # The old system tilt was in the None convention
292+ old_rmat = _tilt_to_rmat (system_tilt , None )
293+ new_rmat = _tilt_to_rmat (new_system_tilt , euler_convention )
271294
272295 # Compute the rmat used to convert from old to new
273296 rmat_diff = new_rmat @ old_rmat .T
@@ -276,19 +299,26 @@ def update_system_constrained_detector_parameters(instr, params, euler_conventio
276299 for panel in instr .detectors .values ():
277300 panel .tilt = _rmat_to_tilt (rmat_diff @ panel .rmat )
278301
279- # Also rotate the detectors about the center
302+ # Also rotate the detectors about the mean center
280303 panel .tvec = rmat_diff @ (panel .tvec - mean_center ) + mean_center
281304
305+ # Update the system tilt
306+ system_tilt [:] = _rmat_to_tilt (new_rmat )
307+
282308 if any (params [x ].vary for x in tvec_names ):
283309 # Find the change in center and shift all tvecs
284- new_mean_center = np .array ([params [x ].value for x in tvec_names ])
310+ new_system_tvec = np .array ([params [x ].value for x in tvec_names ])
285311
286- diff = new_mean_center - mean_center
312+ diff = new_system_tvec - system_tvec
287313 for panel in instr .detectors .values ():
288314 panel .tvec += diff
289315
316+ # Update the system tvec
317+ system_tvec [:] = new_system_tvec
318+
290319
291- def _tilt_to_rmat (tilt : np .ndarray , euler_convention : dict | tuple ) -> np .ndarray :
320+ def _tilt_to_rmat (tilt : np .ndarray ,
321+ euler_convention : dict | tuple ) -> np .ndarray :
292322 # Convert the tilt to exponential map parameters, and then
293323 # to the rotation matrix, and return.
294324 if euler_convention is None :
0 commit comments