@@ -28,29 +28,45 @@ def get_parity_plot(model,
28
28
fig = plt .figure (** fig_kwargs )
29
29
ax = fig .add_subplot (111 )
30
30
31
- uv_pred , uv_true = get_predictions (model , data_module .train_dataloader (),
32
- log , log_constant )
33
- ax .plot (uv_true .flatten (),
34
- uv_pred .flatten (),
35
- 'o' ,
36
- alpha = 0.2 ,
37
- label = 'Train Data' )
38
-
39
- uv_pred , uv_true = get_predictions (model , data_module .val_dataloader (),
40
- log , log_constant )
41
- ax .plot (uv_true .flatten (),
42
- uv_pred .flatten (),
43
- 'X' ,
44
- alpha = 0.2 ,
45
- label = 'Val Data' )
46
-
47
- uv_pred , uv_true = get_predictions (model , data_module .test_dataloader (),
48
- log , log_constant )
49
- ax .plot (uv_true .flatten (),
50
- uv_pred .flatten (),
51
- 'D' ,
52
- alpha = 0.2 ,
53
- label = 'Test Data' )
31
+ if data_module .train_dataset is not None :
32
+ uv_pred , uv_true = get_predictions (model ,
33
+ data_module .train_dataloader (), log ,
34
+ log_constant )
35
+ ax .plot (uv_true .flatten (),
36
+ uv_pred .flatten (),
37
+ 'o' ,
38
+ alpha = 0.2 ,
39
+ label = 'Train Data' )
40
+
41
+ if data_module .val_dataset is not None :
42
+ uv_pred , uv_true = get_predictions (model , data_module .val_dataloader (),
43
+ log , log_constant )
44
+ ax .plot (uv_true .flatten (),
45
+ uv_pred .flatten (),
46
+ 'X' ,
47
+ alpha = 0.2 ,
48
+ label = 'Val Data' )
49
+
50
+ if data_module .iid_test_dataset is not None :
51
+ uv_pred , uv_true = get_predictions (model ,
52
+ data_module .iid_test_dataloader (),
53
+ log , log_constant )
54
+ ax .plot (uv_true .flatten (),
55
+ uv_pred .flatten (),
56
+ 'D' ,
57
+ alpha = 0.2 ,
58
+ label = 'ID Test Data' )
59
+
60
+ if data_module .test_dataset is not None :
61
+ uv_pred , uv_true = get_predictions (model ,
62
+ data_module .test_dataloader (), log ,
63
+ log_constant )
64
+ ax .plot (uv_true .flatten (),
65
+ uv_pred .flatten (),
66
+ 'D' ,
67
+ alpha = 0.2 ,
68
+ label = 'OOD Test Data' )
69
+
54
70
ax .plot ([0 , max (max (ax .get_xlim ()), max (ax .get_ylim ()))],
55
71
[0 , max (max (ax .get_xlim ()), max (ax .get_ylim ()))], 'k--' )
56
72
0 commit comments