Skip to content

Commit 18731f9

Browse files
authored
Merge pull request #51 from sivonxay/main
Miscellaneous Changes
2 parents e2fa19c + 1eb5d12 commit 18731f9

File tree

8 files changed

+155
-126
lines changed

8 files changed

+155
-126
lines changed

src/NanoParticleTools/builder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,14 @@ def __init__(self,
5353
**kwargs)
5454
self.connect()
5555

56-
def get_grouped_docs(self) -> List[Dict]:
56+
def get_grouped_docs(self, additional_keys: List = None) -> List[Dict]:
5757
group_keys = [
5858
"data.n_dopants", "data.n_dopant_sites", "data.formula",
5959
"data.nanostructure_size", "data.formula_by_constraint",
6060
"data.excitation_power", "data.excitation_wavelength"
6161
]
62+
if additional_keys is not None and isinstance(additional_keys, list):
63+
group_keys += additional_keys
6264
return self.source.groupby(keys=group_keys,
6365
criteria=self.docs_filter,
6466
properties=["_id"])

src/NanoParticleTools/machine_learning/data/utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
import os
66

77
SUNSET_SPECIES_TABLE = {
8-
1: ["Yb", "Er", "Mg"],
9-
2: ["Yb", "Er"],
10-
3: ["Yb", "Er", "Mg", "Tm"],
11-
4: ["Yb", "Er"],
12-
5: ["Yb", "Er", "Nd"]
8+
1: sorted(["Yb", "Er", "Xsurfacesix"]),
9+
2: sorted(["Yb", "Er"]),
10+
3: sorted(["Yb", "Er", "Xsurfacesix", "Tm"]),
11+
4: sorted(["Yb", "Er"]),
12+
5: sorted(["Yb", "Er", "Nd"]),
13+
6: sorted(['Yb', 'Er', "Xsurfacesix", 'Tm', 'Nd', 'Ho', 'Eu', 'Sm', 'Dy'])
1314
}
1415

1516

src/NanoParticleTools/machine_learning/modules/ensemble.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def ensemble_forward(self, data: Data,
2828
output.append(y_hat)
2929

3030
x = torch.cat(output, dim=-1)
31-
return {'y': x, 'y_hat': x.mean(-1), 'std': x.std()}
31+
return {'y': x, 'y_hat': x.mean(-1), 'std': x.std(-1)}
3232

3333
def evaluate_step(self, data: Data) -> tuple[torch.Tensor, torch.Tensor]:
3434
output = []
@@ -52,6 +52,6 @@ def predict_step(
5252

5353
x = torch.cat(output, dim=-1)
5454
if return_stats:
55-
return {'y': x, 'y_hat': x.mean(-1), 'std': x.std()}
55+
return {'y': x, 'y_hat': x.mean(-1), 'std': x.std(-1)}
5656
else:
5757
return x.mean(-1)

src/NanoParticleTools/optimization/callbacks.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from NanoParticleTools.util.visualization import plot_nanoparticle_from_arrays
1+
from NanoParticleTools.util.visualization import plot_nanoparticle
22
from NanoParticleTools.machine_learning.data import FeatureProcessor
33

44
from maggma.stores import Store
@@ -11,7 +11,8 @@
1111
from uuid import uuid4
1212

1313

14-
def get_plotting_fn(feature_processor: FeatureProcessor) -> Callable:
14+
def get_plotting_fn(feature_processor: FeatureProcessor,
15+
as_np_array: bool = False) -> Callable:
1516
n_elements = len(feature_processor.possible_elements)
1617

1718
def plotting_fn(x, f=None, accept=None):
@@ -20,19 +21,30 @@ def plotting_fn(x, f=None, accept=None):
2021

2122
plt.figure()
2223
n_constraints = len(x) // (n_elements + 1)
23-
plot_nanoparticle_from_arrays(
24+
fig = plot_nanoparticle(
2425
np.concatenate(([0], x[-n_constraints:])),
2526
x[:-n_constraints].reshape(n_constraints, -1),
26-
dpi=80,
27-
elements=feature_processor.possible_elements,
28-
)
27+
dpi=300,
28+
elements=feature_processor.possible_elements)
2929
if f is not None:
3030
plt.text(0.1,
3131
0.95,
3232
f'UV Intensity={np.power(10, -f)-100:.2f}',
3333
fontsize=20,
3434
transform=plt.gca().transAxes)
35-
return plt
35+
if as_np_array:
36+
# If we haven't already shown or saved the plot, then we need to
37+
# draw the figure first.
38+
fig.canvas.draw()
39+
# Now we can save it to a numpy array.
40+
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
41+
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
42+
43+
# Close the figure to remove it from the buffer
44+
plt.close(fig)
45+
return data
46+
else:
47+
return fig
3648

3749
return plotting_fn
3850

src/NanoParticleTools/optimization/scipy_optimize.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from collections.abc import Callable
1313

1414

15-
def get_bounds(n_constraints: int, n_elements: int) -> Bounds:
15+
def get_bounds(n_constraints: int, n_elements: int, **kwargs) -> Bounds:
1616
r"""
1717
Get the Bounds which are utilized by scipy minimize.
1818
@@ -34,7 +34,8 @@ def get_bounds(n_constraints: int, n_elements: int) -> Bounds:
3434
(np.zeros(num_dopant_nodes), np.zeros(n_constraints)))
3535
max_bounds = np.concatenate(
3636
(np.ones(num_dopant_nodes), np.ones(n_constraints)))
37-
bounds = Bounds(min_bounds, max_bounds)
37+
min_bounds[-1] = 1
38+
bounds = Bounds(min_bounds, max_bounds, **kwargs)
3839
return bounds
3940

4041

src/NanoParticleTools/species_data/species.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,14 @@ class Dopant(MSONable):
8484

8585
# The naming convention should always start with an X, since the symbol
8686
# cannot start with an existing element's symbol
87+
LEGACY_SURFACE_NAMES = {
88+
'Na': 'Surface',
89+
'Al': 'Surface3',
90+
'Si': 'Surface4',
91+
'P': 'Surface5',
92+
'Mg': 'Surface6',
93+
}
94+
8795
SURFACE_DOPANT_SYMBOLS_TO_NAMES = {
8896
'Xsurfaceone': 'Surface',
8997
'Xsurfacethree': 'Surface3',
@@ -117,7 +125,16 @@ class Dopant(MSONable):
117125
def __init__(self,
118126
symbol: str,
119127
molar_concentration: float,
120-
n_levels: int | None = None):
128+
n_levels: int | None = None,
129+
legacy_calc: bool = False):
130+
self.legacy_calc = legacy_calc
131+
132+
if self.legacy_calc:
133+
# If this is an older (legacy) calc, we need to convert the hacked
134+
# surface species to the current.
135+
if symbol in self.LEGACY_SURFACE_NAMES:
136+
symbol = self.LEGACY_SURFACE_NAMES[symbol]
137+
121138
if symbol in self.SURFACE_DOPANT_NAMES_TO_SYMBOLS:
122139
symbol = self.SURFACE_DOPANT_NAMES_TO_SYMBOLS[symbol]
123140
self.symbol = symbol

src/NanoParticleTools/util/visualization.py

Lines changed: 92 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -16,125 +16,117 @@
1616
}
1717

1818

19-
def plot_nanoparticle_from_arrays(radii: np.array,
20-
concentrations: np.array,
21-
dpi=150,
22-
as_np_array=False,
23-
elements=['Yb', 'Er', 'Nd']):
19+
def plot_nanoparticle(radii: np.ndarray | list[NanoParticleConstraint],
20+
concentrations: np.array = None,
21+
dopant_specifications: list[tuple] = None,
22+
dpi=150,
23+
as_np_array=False,
24+
elements=['Yb', 'Er', 'Nd'],
25+
ax: plt.Axes = None,
26+
emissions: float = None):
2427
if 'Y' not in elements:
28+
# Add Y, the host element
2529
elements = elements + ['Y']
2630

27-
# Fill in the concentrations with Y
28-
concentrations_with_y = np.concatenate(
29-
(concentrations, 1 - concentrations.sum(axis=1, keepdims=True)),
30-
axis=1)
31-
31+
if isinstance(radii[0], NanoParticleConstraint):
32+
# Convert this to an array
33+
radii = np.array([0] + [c.radius for c in radii])
34+
if not isinstance(radii, np.ndarray):
35+
# If it is a list, it is already in the format we require
36+
raise TypeError(
37+
'radii should be an array of radii or list of contraints')
38+
39+
if concentrations is None and dopant_specifications is None:
40+
raise RuntimeError(
41+
'Must specify one of concentrations or dopant specifications')
42+
elif dopant_specifications is not None:
43+
# convert this to an array
44+
n_layers = len(radii) - 1
45+
dopant_dict = [{key: 0 for key in elements} for _ in range(n_layers)]
46+
for dopant in dopant_specifications:
47+
dopant_dict[dopant[0]][dopant[2]] = dopant[1]
48+
49+
# Fill in the rest with 'Y'
50+
for layer in dopant_dict:
51+
layer['Y'] = 1 - sum(layer.values())
52+
53+
vals = [[layer[el] for el in elements] for layer in dopant_dict]
54+
concentrations = np.array(vals)
55+
elif concentrations is not None:
56+
# Add Y into the list
57+
if len(elements) != concentrations.shape[1]:
58+
concentrations = np.concatenate(
59+
(concentrations,
60+
1 - concentrations.sum(axis=1, keepdims=True)),
61+
axis=1)
62+
63+
concentrations = np.clip(concentrations, 0, 1)
3264
colors = [
3365
DEFAULT_COLOR_MAP[el]
3466
if el in DEFAULT_COLOR_MAP else DEFAULT_COLOR_MAP['Other']
3567
for el in elements
3668
]
37-
# cmap = plt.colormaps["tab10"]
38-
# colors = cmap(np.arange(4))
39-
# # colors[:3] = colors[1:]
40-
# colors[-1] = [1, 1, 1, 1]
41-
42-
fig = plt.figure(figsize=(5, 5), dpi=dpi)
43-
ax = fig.subplots()
4469

45-
for i in range(concentrations.shape[0], 0, -1):
46-
ax.pie(concentrations_with_y[i - 1],
47-
radius=radii[i] / radii[-1],
48-
colors=colors,
49-
wedgeprops=dict(edgecolor='k', linewidth=0.25),
50-
startangle=90)
51-
ax.legend(elements, loc='upper left', bbox_to_anchor=(0.84, 0.95))
52-
plt.tight_layout()
53-
if as_np_array:
54-
# If we haven't already shown or saved the plot, then we need to
55-
# draw the figure first.
56-
fig.canvas.draw()
57-
58-
# Now we can save it to a numpy array.
59-
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
60-
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
61-
62-
# Close the figure to remove it from the buffer
63-
plt.close(fig)
64-
return data
70+
if ax is None:
71+
# make a new axis
72+
fig = plt.figure(figsize=(5, 5), dpi=dpi)
73+
ax = fig.subplots()
74+
75+
for i in range(concentrations.shape[0], 0, -1):
76+
ax.pie(concentrations[i - 1],
77+
radius=radii[i] / radii[-1],
78+
colors=colors,
79+
wedgeprops=dict(edgecolor='w', linewidth=0.25),
80+
startangle=90)
81+
ax.legend(elements, loc='upper left', bbox_to_anchor=(0.84, 0.95))
82+
if emissions:
83+
plt.text(0.1,
84+
0.95,
85+
f'UV Intensity={np.power(10, -emissions)-100:.2f}',
86+
fontsize=20,
87+
transform=plt.gca().transAxes)
88+
plt.tight_layout()
89+
if as_np_array:
90+
# If we haven't already shown or saved the plot, then we need to
91+
# draw the figure first.
92+
fig.canvas.draw()
93+
94+
# Now we can save it to a numpy array.
95+
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
96+
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
97+
98+
# Close the figure to remove it from the buffer
99+
plt.close(fig)
100+
return data
101+
else:
102+
return fig
65103
else:
66-
return fig
67-
68-
69-
def plot_nanoparticle(constraints,
70-
dopant_specifications,
71-
dpi=150,
72-
as_np_array=False,
73-
elements=['Yb', 'Er', 'Nd']):
74-
if 'Y' not in elements:
75-
elements = elements + ['Y']
76-
77-
n_layers = len(constraints)
78-
radii = [0] + [constraint.radius for constraint in constraints]
79-
dopant_dict = [{key: 0 for key in elements} for _ in range(n_layers)]
80-
for dopant in dopant_specifications:
81-
dopant_dict[dopant[0]][dopant[2]] = dopant[1]
82-
83-
# Fill in the rest with 'Y'
84-
for layer in dopant_dict:
85-
layer['Y'] = 1 - sum(layer.values())
86-
87-
vals = [[layer[el] for el in elements] for layer in dopant_dict]
88-
89-
return plot_nanoparticle_from_arrays(np.array(radii),
90-
np.array(vals),
91-
dpi=dpi,
92-
as_np_array=as_np_array,
93-
elements=elements)
94-
95-
96-
def plot_nanoparticle_on_ax(ax,
97-
constraints,
98-
dopant_specifications,
99-
elements=['Yb', 'Er', 'Nd']):
100-
if 'Y' not in elements:
101-
elements = ['Y'] + elements
102-
103-
n_layers = len(constraints)
104-
radii = [constraint.radius for constraint in constraints]
105-
dopant_dict = [{key: 0 for key in elements} for _ in range(n_layers)]
106-
for dopant in dopant_specifications:
107-
dopant_dict[dopant[0]][dopant[2]] = dopant[1]
108-
# Fill in the rest with 'Y'
109-
for layer in dopant_dict:
110-
layer['Y'] = np.round(1 - sum(layer.values()), 3)
111-
112-
vals = [[layer[el] for el in elements] for layer in dopant_dict]
113-
cmap = plt.colormaps["tab10"]
114-
colors = cmap(np.arange(4) * 4)
115-
colors[0] = [1, 1, 1, 1]
116-
117-
for i in list(range(n_layers - 1, -1, -1)):
118-
# print(vals[i])
119-
ax.pie(vals[i],
120-
radius=radii[i] / radii[-1],
121-
colors=colors,
122-
wedgeprops=dict(edgecolor='k'),
123-
startangle=90)
124-
ax.legend(elements, loc='upper left', bbox_to_anchor=(1, 1))
104+
for i in range(concentrations.shape[0], 0, -1):
105+
ax.pie(concentrations[i - 1],
106+
radius=radii[i] / radii[-1],
107+
colors=colors,
108+
wedgeprops=dict(edgecolor='w', linewidth=0.25),
109+
startangle=90)
110+
ax.legend(elements, loc='upper left', bbox_to_anchor=(0.84, 0.95))
111+
if emissions:
112+
plt.text(0.1,
113+
0.95,
114+
f'UV Intensity={np.power(10, -emissions)-100:.2f}',
115+
fontsize=20,
116+
transform=plt.gca().transAxes)
125117

126118

127119
def update(data, ax):
128-
constraints, dopants = data
129120
ax.clear()
130-
plot_nanoparticle_on_ax(ax, constraints, dopants)
121+
plot_nanoparticle(ax=ax, **data)
131122

132123

133124
def make_animation(frames: List[Tuple[NanoParticleConstraint, Tuple]],
134125
name: str = 'animation.mp4',
135-
fps: int = 30) -> None:
126+
fps: int = 30,
127+
dpi: int = 300) -> None:
136128

137-
fig = plt.figure(dpi=150)
129+
fig = plt.figure(dpi=dpi)
138130
ax = fig.subplots()
139131
anim = animation.FuncAnimation(fig, partial(update, ax=ax), frames=frames)
140132
anim.save(name, fps=fps)

0 commit comments

Comments
 (0)