Skip to content

Commit 785431e

Browse files
authored
Merge pull request #845 from HEXRD/calcstar-speedup
Calcstar speedup
2 parents e0e7938 + 48cb459 commit 785431e

File tree

2 files changed

+84
-48
lines changed

2 files changed

+84
-48
lines changed

hexrd/core/material/unitcell.py

Lines changed: 20 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,31 @@ def _calclength(u, mat):
3333

3434
@njit(cache=True, nogil=True)
3535
def _calcstar(v, sym, mat):
36-
vsym = np.atleast_2d(v)
37-
for s in sym:
36+
vsym = np.empty((sym.shape[0], v.shape[0]))
37+
nsym = sym.shape[0]
38+
39+
n = 0
40+
vsym[n, :] = v
41+
n = n + 1
42+
43+
# the first element is always the identity
44+
# so we can safely skip that
45+
for s in sym[1:, :, :]:
3846
vp = np.dot(np.ascontiguousarray(s), v)
47+
3948
# check if this is new
4049
isnew = True
41-
for vec in vsym:
42-
vv = vp - vec
43-
dist = _calclength(vv, mat)
44-
if dist < 1e-3:
50+
for vec in vsym[0:n, :]:
51+
dist = np.sum((vp - vec) ** 2)
52+
if dist < 1e-4:
4553
isnew = False
4654
break
55+
4756
if isnew:
48-
vp = np.atleast_2d(vp)
49-
vsym = np.vstack((vsym, vp))
57+
vsym[n, :] = vp
58+
n = n + 1
5059

51-
return vsym
60+
return vsym[0:n, :]
5261

5362

5463
class unitcell:
@@ -74,7 +83,6 @@ def __init__(
7483
beamenergy,
7584
sgsetting=0,
7685
):
77-
7886
self._tstart = time.time()
7987
self.pref = 0.4178214
8088

@@ -132,7 +140,6 @@ def CalcWavelength(self):
132140
self.wavelength *= 1e9
133141

134142
def calcBetaij(self):
135-
136143
self.betaij = np.zeros([3, 3, self.atom_ntype])
137144
for i in range(self.U.shape[0]):
138145
U = self.U[i, :]
@@ -143,7 +150,6 @@ def calcBetaij(self):
143150
self.betaij[:, :, i] *= 2.0 * np.pi**2 * self._aij
144151

145152
def calcmatrices(self):
146-
147153
a = self.a
148154
b = self.b
149155
c = self.c
@@ -266,7 +272,6 @@ def TransSpace(self, v_in, inspace, outspace):
266272
''' calculate dot product of two vectors in any space 'd' 'r' or 'c' '''
267273

268274
def CalcDot(self, u, v, space):
269-
270275
if space == 'd':
271276
dot = np.dot(u, np.dot(self.dmt, v))
272277
elif space == 'r':
@@ -279,7 +284,6 @@ def CalcDot(self, u, v, space):
279284
return dot
280285

281286
def CalcLength(self, u, space):
282-
283287
if space == 'd':
284288
mat = self.dmt
285289
# vlen = np.sqrt(np.dot(u, np.dot(self.dmt, u)))
@@ -304,7 +308,6 @@ def NormVec(self, u, space):
304308
''' calculate angle between two vectors in any space'''
305309

306310
def CalcAngle(self, u, v, space):
307-
308311
ulen = self.CalcLength(u, space)
309312
vlen = self.CalcLength(v, space)
310313

@@ -401,7 +404,6 @@ def CalcCross(self, p, q, inspace, outspace, vol_divide=False):
401404
return pxq
402405

403406
def GenerateRecipPGSym(self):
404-
405407
self.SYM_PG_r = self.SYM_PG_d[0, :, :]
406408
self.SYM_PG_r = np.broadcast_to(self.SYM_PG_r, [1, 3, 3])
407409

@@ -473,7 +475,6 @@ def GenerateCartesianPGSym(self):
473475
self.SYM_PG_supergroup_laue = sym_supergroup_laue
474476

475477
else:
476-
477478
self.SYM_PG_supergroup = []
478479
self.SYM_PG_supergroup_laue = []
479480

@@ -506,7 +507,6 @@ def GenerateCartesianPGSym(self):
506507
SS 12/10/2020
507508
'''
508509
if self.latticeType == 'monoclinic':
509-
510510
om = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, -1.0, 0.0]])
511511

512512
for i, s in enumerate(self.SYM_PG_c):
@@ -616,7 +616,6 @@ def CalcPositions(self):
616616
asym_pos = []
617617

618618
for i in range(self.atom_ntype):
619-
620619
v = self.atom_pos[i, 0:3]
621620
apos, n = self.CalcOrbit(v)
622621

@@ -784,7 +783,6 @@ def CalcMaxGIndex(self):
784783
self.il = self.il + 1
785784

786785
def InitializeInterpTable(self):
787-
788786
f_anomalous_data = []
789787
self.pe_cs = {}
790788
data = (
@@ -794,7 +792,6 @@ def InitializeInterpTable(self):
794792
)
795793
with h5py.File(data, 'r') as fid:
796794
for i in range(0, self.atom_ntype):
797-
798795
Z = self.atom_type[i]
799796
elem = constants.ptableinverse[Z]
800797

@@ -926,7 +923,6 @@ def calc_number_density(self):
926923
return 1e-12 * self.density * Na / M
927924

928925
def calc_absorption_cross_sec(self):
929-
930926
abs_cs_total = 0.0
931927
for i in range(self.atom_ntype):
932928
Z = self.atom_type[i]
@@ -983,8 +979,7 @@ def ChooseSymmetric(self, hkllist, InversionSymmetry=True):
983979

984980
for i, g in enumerate(hkllist):
985981
if mask[i]:
986-
987-
geqv = self.CalcStar(g, 'r', applyLaue=laue)
982+
geqv = self.CalcStar(g, 'r', applyLaue=laue).astype(int)
988983

989984
for r in geqv[1:,]:
990985
rid = np.where(np.all(r == hkllist, axis=1))
@@ -1069,10 +1064,8 @@ def getHKLs(self, dmin):
10691064
hkl_dsp = []
10701065

10711066
for g in hkl_allowed:
1072-
10731067
# ignore [0 0 0] as it is the direct beam
10741068
if np.sum(np.abs(g)) != 0:
1075-
10761069
dspace = 1.0 / self.CalcLength(g, 'r')
10771070

10781071
if dspace >= dmin:
@@ -1125,7 +1118,6 @@ def MakeStiffnessMatrix(self, inp_Cvals):
11251118
# initialize all zeros and fill the supplied values
11261119
C = np.zeros([6, 6])
11271120
for i, x in enumerate(_StiffnessDict[self._laueGroup][0]):
1128-
11291121
C[x] = inp_Cvals[i]
11301122

11311123
# enforce the equality constraints
@@ -1176,7 +1168,6 @@ def inside_spheretriangle(self, conn, dir3, hemisphere, switch):
11761168

11771169
mask = []
11781170
for x in dir3:
1179-
11801171
x2 = np.atleast_2d(x).T
11811172
d1 = np.linalg.det(np.hstack((A, B, x2)))
11821173
d2 = np.linalg.det(np.hstack((A, x2, C)))
@@ -1293,9 +1284,7 @@ def reduce_dirvector(self, dir3, switch='pg'):
12931284
sym = self.SYM_PG_supergroup_laue
12941285

12951286
for sop in sym:
1296-
12971287
if dir3_copy.size != 0:
1298-
12991288
dir3_sym = np.dot(sop, dir3_copy.T).T
13001289

13011290
mask = np.zeros(dir3_sym.shape[0]).astype(bool)
@@ -1783,7 +1772,6 @@ def vol_per_atom(self):
17831772
# vol per atom in A^3
17841773
return 1e3 * self.vol / self.num_atom
17851774

1786-
17871775
@property
17881776
def chemical_formula(self):
17891777
chemical_formula = ''
@@ -1797,7 +1785,7 @@ def chemical_formula(self):
17971785
elem = constants.ptableinverse[Z]
17981786
numat = self.numat[i]
17991787
occ = self.atom_pos[i, 3]
1800-
abundance = str(numat*occ)
1788+
abundance = str(numat * occ)
18011789
if abundance.endswith('.0'):
18021790
# We can remove the trailing decimal and zero.
18031791
# This looks nicer if you print the formula,

tests/unitcell/test_vec_math.py

Lines changed: 64 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -85,19 +85,67 @@ def test_trans_space(cell: unitcell.unitcell):
8585

8686

8787
def test_calc_star(cell: unitcell.unitcell):
88-
"""
89-
Just ensuring that the outspace doesn't matter
90-
"""
91-
np.random.seed(0)
92-
for _ in range(100):
93-
v1 = np.random.rand(3) * 10 - 5
94-
space = np.random.choice(['d', 'r'])
95-
v1c = cell.TransSpace(v1, space, 'c')
96-
assert np.allclose(
97-
cell.CalcStar(v1, space, False),
98-
cell.TransSpace(cell.CalcStar(v1c, 'c', False), 'c', space),
99-
)
100-
assert np.allclose(
101-
cell.CalcStar(v1, space, True),
102-
cell.TransSpace(cell.CalcStar(v1c, 'c', True), 'c', space),
103-
)
88+
v = [1.0, 2.0, 3.0]
89+
vsym = cell.CalcStar(v, 'd', False)
90+
assert vsym.shape[0] == 48
91+
92+
vsym = cell.CalcStar(v, 'r', False)
93+
assert vsym.shape[0] == 48
94+
95+
vsym = cell.CalcStar(v, 'd', True)
96+
assert vsym.shape[0] == 48
97+
98+
vsym = cell.CalcStar(v, 'r', True)
99+
assert vsym.shape[0] == 48
100+
101+
v = [1.0, 1.0, 3.0]
102+
vsym = cell.CalcStar(v, 'd', False)
103+
assert vsym.shape[0] == 24
104+
105+
vsym = cell.CalcStar(v, 'r', False)
106+
assert vsym.shape[0] == 24
107+
108+
vsym = cell.CalcStar(v, 'd', True)
109+
assert vsym.shape[0] == 24
110+
111+
vsym = cell.CalcStar(v, 'r', True)
112+
assert vsym.shape[0] == 24
113+
114+
v = [1.0, 1.0, 0.0]
115+
vsym = cell.CalcStar(v, 'd', False)
116+
assert vsym.shape[0] == 12
117+
118+
vsym = cell.CalcStar(v, 'r', False)
119+
assert vsym.shape[0] == 12
120+
121+
vsym = cell.CalcStar(v, 'd', True)
122+
assert vsym.shape[0] == 12
123+
124+
vsym = cell.CalcStar(v, 'r', True)
125+
assert vsym.shape[0] == 12
126+
127+
v = [1.0, 1.0, 1.0]
128+
vsym = cell.CalcStar(v, 'd', False)
129+
assert vsym.shape[0] == 8
130+
131+
vsym = cell.CalcStar(v, 'r', False)
132+
assert vsym.shape[0] == 8
133+
134+
vsym = cell.CalcStar(v, 'd', True)
135+
assert vsym.shape[0] == 8
136+
137+
vsym = cell.CalcStar(v, 'r', True)
138+
assert vsym.shape[0] == 8
139+
140+
v = [1.0, 0.0, 0.0]
141+
vsym = cell.CalcStar(v, 'd', False)
142+
assert vsym.shape[0] == 6
143+
144+
vsym = cell.CalcStar(v, 'r', False)
145+
assert vsym.shape[0] == 6
146+
147+
vsym = cell.CalcStar(v, 'd', True)
148+
assert vsym.shape[0] == 6
149+
150+
vsym = cell.CalcStar(v, 'r', True)
151+
assert vsym.shape[0] == 6

0 commit comments

Comments
 (0)