Skip to content

Commit

Permalink
In plotting, plot all splits if available.
Browse files Browse the repository at this point in the history
  • Loading branch information
sivonxay committed Aug 24, 2023
1 parent aba6f6e commit 3ed4c18
Showing 1 changed file with 39 additions and 23 deletions.
62 changes: 39 additions & 23 deletions src/NanoParticleTools/machine_learning/util/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,29 +28,45 @@ def get_parity_plot(model,
fig = plt.figure(**fig_kwargs)
ax = fig.add_subplot(111)

uv_pred, uv_true = get_predictions(model, data_module.train_dataloader(),
log, log_constant)
ax.plot(uv_true.flatten(),
uv_pred.flatten(),
'o',
alpha=0.2,
label='Train Data')

uv_pred, uv_true = get_predictions(model, data_module.val_dataloader(),
log, log_constant)
ax.plot(uv_true.flatten(),
uv_pred.flatten(),
'X',
alpha=0.2,
label='Val Data')

uv_pred, uv_true = get_predictions(model, data_module.test_dataloader(),
log, log_constant)
ax.plot(uv_true.flatten(),
uv_pred.flatten(),
'D',
alpha=0.2,
label='Test Data')
if data_module.train_dataset is not None:
uv_pred, uv_true = get_predictions(model,
data_module.train_dataloader(), log,
log_constant)
ax.plot(uv_true.flatten(),
uv_pred.flatten(),
'o',
alpha=0.2,
label='Train Data')

if data_module.val_dataset is not None:
uv_pred, uv_true = get_predictions(model, data_module.val_dataloader(),
log, log_constant)
ax.plot(uv_true.flatten(),
uv_pred.flatten(),
'X',
alpha=0.2,
label='Val Data')

if data_module.iid_test_dataset is not None:
uv_pred, uv_true = get_predictions(model,
data_module.iid_test_dataloader(),
log, log_constant)
ax.plot(uv_true.flatten(),
uv_pred.flatten(),
'D',
alpha=0.2,
label='ID Test Data')

if data_module.test_dataset is not None:
uv_pred, uv_true = get_predictions(model,
data_module.test_dataloader(), log,
log_constant)
ax.plot(uv_true.flatten(),
uv_pred.flatten(),
'D',
alpha=0.2,
label='OOD Test Data')

ax.plot([0, max(max(ax.get_xlim()), max(ax.get_ylim()))],
[0, max(max(ax.get_xlim()), max(ax.get_ylim()))], 'k--')

Expand Down

0 comments on commit 3ed4c18

Please sign in to comment.