Skip to content

Commit 3ed4c18

Browse files
committed
In plotting, plot all splits if available.
1 parent aba6f6e commit 3ed4c18

File tree

1 file changed

+39
-23
lines changed

1 file changed

+39
-23
lines changed

src/NanoParticleTools/machine_learning/util/visualization.py

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -28,29 +28,45 @@ def get_parity_plot(model,
2828
fig = plt.figure(**fig_kwargs)
2929
ax = fig.add_subplot(111)
3030

31-
uv_pred, uv_true = get_predictions(model, data_module.train_dataloader(),
32-
log, log_constant)
33-
ax.plot(uv_true.flatten(),
34-
uv_pred.flatten(),
35-
'o',
36-
alpha=0.2,
37-
label='Train Data')
38-
39-
uv_pred, uv_true = get_predictions(model, data_module.val_dataloader(),
40-
log, log_constant)
41-
ax.plot(uv_true.flatten(),
42-
uv_pred.flatten(),
43-
'X',
44-
alpha=0.2,
45-
label='Val Data')
46-
47-
uv_pred, uv_true = get_predictions(model, data_module.test_dataloader(),
48-
log, log_constant)
49-
ax.plot(uv_true.flatten(),
50-
uv_pred.flatten(),
51-
'D',
52-
alpha=0.2,
53-
label='Test Data')
31+
if data_module.train_dataset is not None:
32+
uv_pred, uv_true = get_predictions(model,
33+
data_module.train_dataloader(), log,
34+
log_constant)
35+
ax.plot(uv_true.flatten(),
36+
uv_pred.flatten(),
37+
'o',
38+
alpha=0.2,
39+
label='Train Data')
40+
41+
if data_module.val_dataset is not None:
42+
uv_pred, uv_true = get_predictions(model, data_module.val_dataloader(),
43+
log, log_constant)
44+
ax.plot(uv_true.flatten(),
45+
uv_pred.flatten(),
46+
'X',
47+
alpha=0.2,
48+
label='Val Data')
49+
50+
if data_module.iid_test_dataset is not None:
51+
uv_pred, uv_true = get_predictions(model,
52+
data_module.iid_test_dataloader(),
53+
log, log_constant)
54+
ax.plot(uv_true.flatten(),
55+
uv_pred.flatten(),
56+
'D',
57+
alpha=0.2,
58+
label='ID Test Data')
59+
60+
if data_module.test_dataset is not None:
61+
uv_pred, uv_true = get_predictions(model,
62+
data_module.test_dataloader(), log,
63+
log_constant)
64+
ax.plot(uv_true.flatten(),
65+
uv_pred.flatten(),
66+
'D',
67+
alpha=0.2,
68+
label='OOD Test Data')
69+
5470
ax.plot([0, max(max(ax.get_xlim()), max(ax.get_ylim()))],
5571
[0, max(max(ax.get_xlim()), max(ax.get_ylim()))], 'k--')
5672

0 commit comments

Comments
 (0)