@@ -128,43 +128,51 @@ def _classify(alert_lite: pittgoogle.Alert) -> dict:
 
 def _format_for_classifier(alert_lite: pittgoogle.Alert) -> pd.DataFrame:
     """Create a DataFrame for input to SuperNNova."""
+
     alert_lite_dict = alert_lite.dict["alert_lite"]
-    alert_df = _create_dataframe(alert_lite_dict)
-    snn_df = pd.DataFrame(
-        data={
-            # select a subset of columns and rename them for SuperNNova
-            # get_key returns the name that the survey uses for a given field
-            # for the full mapping, see alert.schema.map
-            "SNID": [alert_lite_dict["diaObject"]["diaObjectId"]] * len(alert_df.index),
-            "FLT": alert_df["band"],
-            "MJD": alert_df["midpointMjdTai"],
-            "FLUXCAL": alert_df["psfFlux"],
-            "FLUXCALERR": alert_df["psfFluxErr"],
-        },
-        index=alert_df.index,
+    # create a DataFrame from alert_lite_dict containing the columns SuperNNova expects
+    filtered_df = _create_dataframe(alert_lite_dict)
+    filtered_df = filtered_df.rename(
+        columns={
+            "band": "FLT",
+            "midpointMjdTai": "MJD",
+            "psfFlux": "FLUXCAL",
+            "psfFluxErr": "FLUXCALERR",
+        }
     )
+    filtered_df["SNID"] = alert_lite_dict["diaObject"]["diaObjectId"]
+    filtered_df = filtered_df[["SNID", "FLT", "MJD", "FLUXCAL", "FLUXCALERR"]]
 
-    return snn_df
+    return filtered_df
 
 
-def _create_dataframe(alert_dict: dict) -> "pd.DataFrame":
-    """Return a pandas DataFrame containing the source detections."""
+def _create_dataframe(alert_lite_dict: dict) -> pd.DataFrame:
+    """Create a DataFrame from the alert lite dictionary."""
 
-    # sources and previous sources are expected to have the same fields
-    sources_df = pd.DataFrame(
-        [alert_dict.get("diaSource")] + (alert_dict.get("prvDiaSources") or [])
-    )
-    # sources and forced sources may have different fields
-    forced_df = pd.DataFrame(alert_dict.get("prvDiaForcedSources") or [])
-
-    # use nullable integer data type to avoid converting ints to floats
-    # for columns in one dataframe but not the other
-    sources_ints = [c for c, v in sources_df.dtypes.items() if v == int]
-    sources_df = sources_df.astype(
-        {c: "Int64" for c in set(sources_ints) - set(forced_df.columns)}
-    )
-    forced_ints = [c for c, v in forced_df.dtypes.items() if v == int]
-    forced_df = forced_df.astype({c: "Int64" for c in set(forced_ints) - set(sources_df.columns)})
-    _dataframe = pd.concat([sources_df, forced_df], ignore_index=True)
+    required_cols = [
+        "band",
+        "midpointMjdTai",
+        "psfFlux",
+        "psfFluxErr",
+    ]  # columns required by SuperNNova
+
+    # extract the required fields and create filtered DataFrames
+    sources = [alert_lite_dict.get("diaSource")] + (alert_lite_dict.get("prvDiaSources") or [])
+    forced_sources = alert_lite_dict.get("prvDiaForcedSources") or []
+    sources_df = pd.DataFrame(filter_columns(sources, required_cols))
+    forced_df = pd.DataFrame(filter_columns(forced_sources, required_cols))
+
+    # concatenate diaSource, prvDiaSources, and prvDiaForcedSources into a single DataFrame
+    df = pd.concat([sources_df, forced_df], ignore_index=True)
+
+    return df
+
+
+def filter_columns(field_list, required_cols):
+    """Extract only the relevant columns from each field, if they exist."""
 
-    return _dataframe
+    return [
+        {k: field.get(k) for k in required_cols if k in field}
+        for field in field_list
+        if field is not None
+    ]
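
For reference, a minimal sketch of the new behavior, assuming the post-change _create_dataframe and filter_columns above are in scope (along with the module's pandas import). The alert payload below is hypothetical: the values are invented for illustration, and only the keys mirror the schema fields used in the diff.

# hypothetical alert_lite payload; values invented, keys follow the schema above
alert_lite_dict = {
    "diaObject": {"diaObjectId": 111111},
    "diaSource": {"band": "g", "midpointMjdTai": 60000.1, "psfFlux": 1200.5, "psfFluxErr": 30.2},
    "prvDiaSources": [
        {"band": "r", "midpointMjdTai": 59990.2, "psfFlux": 980.0, "psfFluxErr": 25.7},
    ],
    "prvDiaForcedSources": [
        {"band": "i", "midpointMjdTai": 59980.3, "psfFlux": 450.3, "psfFluxErr": 40.1},
    ],
}

df = _create_dataframe(alert_lite_dict)
# -> 3 rows (diaSource + one prvDiaSource + one prvDiaForcedSource), containing
#    only the columns band, midpointMjdTai, psfFlux, psfFluxErr

Because filter_columns drops keys that are absent, a column missing from the forced sources would simply surface as NaN after the concat, which appears to be why the old Int64 dtype handling could be removed.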