@@ -83,24 +83,34 @@ def run():
8383 model_lite = ORACLE_lite (MODEL_PATH )
8484 oracle_classification = model_lite .predict ([_format_for_classifier (alert_lite )]).to_dict (orient = 'records' )[0 ]
8585
86+ level_1_class = _most_likely_class (oracle_classification , ['Transient' , 'Variable' ])
87+ level_2_class = _most_likely_class (oracle_classification , ['SN' , 'Fast' , 'Long' , # Transient
88+ 'Periodic' , 'AGN' ]) # Variable
89+ leaf_class = _most_likely_class (oracle_classification , ['SNIa' , 'SNIb/c' , 'SNIax' , 'SNI91bg' , 'SNII' , # SN
90+ 'KN' , 'Dwarf Novae' , 'uLens' , 'M-dwarf Flare' , # Fast
91+ 'SLSN' , 'TDE' , 'ILOT' , 'CART' , 'PISN' , # Long
92+ 'Cepheid' , 'RR Lyrae' , 'Delta Scuti' , 'EB' , # Periodic
93+ 'AGN' ]) # AGN
94+
8695 # publish
8796 outpt_dict = {
88- "diaObjectId" : alert_lite .dict ["alert_lite" ]["diaObject" ]["diaObjectId" ],
89- "diaSourceId" : alert_lite .dict ["alert_lite" ]["diaSource" ]["diaSourceId" ],
9097 "output" : oracle_classification ,
91- "predicted_level_1" : round (oracle_classification ),
92- "predicted_level_2" : round (oracle_classification ),
93- "predicted_leaf" : round (oracle_classification ),
98+ "predicted_level_1" : level_1_class [0 ],
99+ "predicted_level_1_prob" : level_1_class [1 ],
100+ "predicted_level_2" : level_2_class [0 ],
101+ "predicted_level_2_prob" : level_2_class [1 ],
102+ "predicted_leaf" : leaf_class [0 ],
103+ "predicted_leaf_prob" : leaf_class [1 ]
94104 }
95105
96106 TOPIC .publish (
97107 pittgoogle .Alert .from_dict (
98- payload = {" alert_lite" : alert_lite .dict , "SCONE" : outpt_dict },
108+ payload = {' alert_lite' : alert_lite .dict [ 'alert_lite' ], 'ORACLE' : outpt_dict },
99109 attributes = {
100110 ** alert_lite .attributes ,
101- "pg_scone_class" : outpt_dict ["predicted_class" ],
111+ 'pg_oracle_class' : outpt_dict ['predicted_leaf' ],
102112 },
103- schema_name = " default" ,
113+ schema_name = ' default' ,
104114 )
105115 )
106116
@@ -111,6 +121,7 @@ def y_to_Y(band):
111121 return 'Y'
112122 return band
113123
124+ # this could use improvement
114125def get_photflag (flux ):
115126 if flux [1 ] > 5 * abs (flux [0 ]):
116127 return 1024
@@ -120,7 +131,7 @@ def get_photflag(flux):
120131 return 4096
121132
122133def _format_for_classifier (alert : pittgoogle .Alert ) -> pd .DataFrame :
123- """Create a DataFrame for input to SCONE ."""
134+ """Create a DataFrame for input to ORACLE ."""
124135 alert_df = alert .dataframe
125136 t = Table ([alert_df [alert .get_key ("mjd" )[1 ]],
126137 alert_df [alert .get_key ("filter" )[1 ]],
@@ -133,4 +144,13 @@ def _format_for_classifier(alert: pittgoogle.Alert) -> pd.DataFrame:
133144 MJD_min = min (t ['MJD' ])
134145 t ['MJD' ] = [MJD - MJD_min for MJD in t ['MJD' ]]
135146 t .sort ('MJD' )
136- return t .to_pandas ()
147+ return t .to_pandas ()
148+
149+ def _most_likely_class (probability_dict : dict , keys : list ) -> tuple [str , float ]:
150+ max_val = 0
151+ max_class = None
152+ for key in keys :
153+ if max_val < probability_dict [key ]:
154+ max_val = probability_dict [key ]
155+ max_class = key
156+ return (max_class , max_val )
0 commit comments