Skip to content

Commit 08d142f

Browse files
authored
Merge pull request #34 from learningmatter-mit/vssr_pourbaix
Merge `vssr_pourbaix`
2 parents eb7d0b0 + 95937c9 commit 08d142f

File tree

18 files changed

+585
-424
lines changed

18 files changed

+585
-424
lines changed

models/foundation_models/chgnet/0.2.0/README.md

Lines changed: 0 additions & 74 deletions
This file was deleted.
Binary file not shown.

models/foundation_models/chgnet/0.3.0/README.md

Lines changed: 0 additions & 80 deletions
This file was deleted.
Binary file not shown.

nff/analysis/loss_plot.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,24 @@
44
from . import mpl_settings
55

66

7-
def plot_loss(energy_history, forces_history, figname, train_key="train", val_key="val"):
7+
def plot_loss(
8+
energy_history: dict,
9+
forces_history: dict,
10+
figname: str,
11+
train_key: str = "train",
12+
val_key: str = "val",
13+
) -> None:
814
"""Plot the loss history of the model.
9-
Args:
10-
energy_history (dict): energy loss history of the model for training and validation
11-
forces_history (dict): forces loss history of the model for training and validation
12-
figname (str): name of the figure
1315
14-
Returns:
15-
None
16+
Args:
17+
energy_history: energy loss history of the model for training and validation
18+
forces_history: forces loss history of the model for training and validation
19+
figname: name of the figure
20+
train_key: key for training data in the history dictionary
21+
val_key: key for validation data in the history dictionary
1622
"""
1723
epochs = np.arange(1, len(energy_history[train_key]) + 1)
18-
fig, ax_fig = plt.subplots(1, 2, figsize=(12, 6), dpi=mpl_settings.DPI)
24+
fig, ax_fig = plt.subplots(1, 2, figsize=(5, 2.5), dpi=mpl_settings.DPI)
1925
ax_fig[0].semilogy(epochs, energy_history[train_key], label="train", color=mpl_settings.colors[1])
2026
ax_fig[0].semilogy(epochs, energy_history[val_key], label="val", color=mpl_settings.colors[2])
2127
ax_fig[0].legend()

nff/analysis/mpl_settings.py

Lines changed: 64 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,44 @@
1+
from __future__ import annotations
2+
13
import json
24
from pathlib import Path
3-
from typing import List, Optional
5+
from typing import List
46

5-
import matplotlib
7+
import matplotlib as mpl
68
import matplotlib.pyplot as plt
79
import numpy as np
810

911
plt.style.use("default")
1012

11-
DPI = 100
12-
LINEWIDTH = 2
13-
FONTSIZE = 20
14-
LABELSIZE = 18
13+
dir_name = Path(__file__).parent
14+
15+
DPI = 300
16+
LINEWIDTH = 1.25
17+
FONTSIZE = 8
18+
LABELSIZE = 8
1519
ALPHA = 0.8
16-
LINE_MARKERSIZE = 15 * 25
17-
MARKERSIZE = 15
18-
GRIDSIZE = 40
19-
MAJOR_TICKLEN = 6
20-
MINOR_TICKLEN = 3
21-
TICKPADDING = 5
20+
MARKERSIZE = 25
21+
GRIDSIZE = 20
22+
MAJOR_TICKLEN = 4
23+
MINOR_TICKLEN = 2
24+
TICKPADDING = 3
2225
SECONDARY_CMAP = "inferno"
2326

24-
params = {
27+
custom_settings = {
2528
"mathtext.default": "regular",
2629
"font.family": "Arial",
2730
"font.size": FONTSIZE,
2831
"axes.labelsize": LABELSIZE,
2932
"axes.titlesize": FONTSIZE,
33+
"axes.linewidth": LINEWIDTH,
3034
"grid.linewidth": LINEWIDTH,
3135
"lines.linewidth": LINEWIDTH,
36+
"lines.color": "black",
37+
"axes.labelcolor": "black",
38+
"axes.edgecolor": "black",
39+
"axes.titlecolor": "black",
40+
"axes.titleweight": "bold",
41+
"axes.grid": False,
3242
"lines.markersize": MARKERSIZE,
3343
"xtick.major.size": MAJOR_TICKLEN,
3444
"xtick.minor.size": MINOR_TICKLEN,
@@ -38,66 +48,67 @@
3848
"ytick.minor.size": MINOR_TICKLEN,
3949
"ytick.major.pad": TICKPADDING,
4050
"ytick.minor.pad": TICKPADDING,
41-
"axes.linewidth": LINEWIDTH,
42-
"legend.fontsize": LABELSIZE,
43-
"figure.dpi": DPI,
44-
"savefig.dpi": DPI,
4551
"ytick.major.width": LINEWIDTH,
4652
"xtick.major.width": LINEWIDTH,
4753
"ytick.minor.width": LINEWIDTH,
4854
"xtick.minor.width": LINEWIDTH,
55+
"legend.fontsize": LABELSIZE,
56+
"figure.dpi": DPI,
57+
"savefig.dpi": DPI,
58+
"savefig.format": "png",
59+
"savefig.bbox": "tight",
60+
"savefig.pad_inches": 0.1,
61+
"figure.facecolor": "white",
4962
}
50-
plt.rcParams.update(params)
63+
plt.rcParams.update(custom_settings)
64+
5165

66+
def update_custom_settings(custom_settings: dict | None = custom_settings) -> None:
67+
"""Update the custom settings for Matplotlib.
5268
53-
def hex_to_rgb(value: str) -> tuple:
69+
Args:
70+
custom_settings: Custom settings for Matplotlib. Defaults to
71+
custom_settings.
5472
"""
55-
Converts hex to rgb colours
73+
current_settings = plt.rcParams.copy()
74+
new_settings = current_settings | custom_settings
75+
plt.rcParams.update(new_settings)
5676

57-
Parameters
58-
----------
59-
value: string of 6 characters representing a hex colour
6077

61-
Returns
62-
----------
63-
tuple of 3 integers representing the RGB values
64-
"""
78+
def hex_to_rgb(value: str) -> list[float]:
79+
"""Converts hex to rgb colors.
6580
81+
Args:
82+
value: string of 6 characters representing a hex color.
83+
"""
6684
value = value.strip("#") # removes hash symbol if present
6785
lv = len(value)
6886
return tuple(int(value[i : i + lv // 3], 16) for i in range(0, lv, lv // 3))
6987

7088

71-
def rgb_to_dec(value: list):
72-
"""
73-
Converts rgb to decimal colours (i.e. divides each value by 256)
89+
def rgb_to_dec(value: list[float]) -> list[float]:
90+
"""Converts rgb to decimal colors (i.e. divides each value by 256).
7491
75-
Parameters
76-
----------
77-
value: list of 3 integers representing the RGB values
92+
Args:
93+
value: string of 6 characters representing a hex color.
7894
79-
Returns
80-
----------
81-
list of 3 floats representing the RGB values
95+
Returns:
96+
list: length 3 of RGB values
8297
"""
83-
8498
return [v / 256 for v in value]
8599

86100

87-
def get_continuous_cmap(hex_list: List[str], float_list: Optional[List[float]] = None) -> matplotlib.colors.Colormap:
88-
"""
89-
Creates and returns a color map that can be used in heat map figures.
90-
If float_list is not provided, colour map graduates linearly between each color in hex_list.
91-
If float_list is provided, each color in hex_list is mapped to the respective location in float_list.
92-
93-
Parameters
94-
----------
95-
hex_list: list of hex code strings
96-
float_list: list of floats between 0 and 1, same length as hex_list. Must start with 0 and end with 1.
97-
98-
Returns
99-
----------
100-
Colormap
101+
def get_continuous_cmap(
102+
hex_list: list[str], float_list: list[float] | None = None
103+
) -> mpl.colors.LinearSegmentedColormap:
104+
"""Creates a color map that can be used in heat map figures. If float_list is not provided,
105+
color map graduates linearly between each color in hex_list. If float_list is provided,
106+
each color in hex_list is mapped to the respective location in float_list.
107+
108+
Args:
109+
hex_list: list of hex code strings
110+
float_list: list of floats between 0 and 1, same length as hex_list.
111+
Must start with 0 and end with 1.
101112
"""
102113
rgb_list = [rgb_to_dec(hex_to_rgb(i)) for i in hex_list]
103114
if float_list:
@@ -109,15 +120,12 @@ def get_continuous_cmap(hex_list: List[str], float_list: Optional[List[float]] =
109120
for num, col in enumerate(["red", "green", "blue"]):
110121
col_list = [[float_list[i], rgb_list[i][num], rgb_list[i][num]] for i in range(len(float_list))]
111122
cdict[col] = col_list
112-
cmp = matplotlib.colors.LinearSegmentedColormap("j_cmap", segmentdata=cdict, N=256)
113-
return cmp
123+
return mpl.colors.LinearSegmentedColormap("j_cmap", segmentdata=cdict, N=256)
114124

115125

116-
# colors taken from Johannes Dietschreit's script and interpolated with correct lightness and Bezier
126+
# Colors taken from Johannes Dietschreit's script and interpolated with correct lightness and Bezier
117127
# http://www.vis4.net/palettes/#/100|s|fce1a4,fabf7b,f08f6e,d12959,6e005f|ffffe0,ff005e,93003a|1|1
118128
hex_list: List[str]
119-
dir_name = Path(__file__).parent
120-
121129
with open(dir_name / "config/mpl_settings.json", "r") as f:
122130
hex_list = json.load(f)["plot_colors"]
123131

0 commit comments

Comments
 (0)