Skip to content

Commit a2584ba

Browse files
committed
optimize SuperNNova helper functions to reduce the amount of RAM needed to process an alert
1 parent 17f2f70 commit a2584ba

File tree

1 file changed

+41
-33
lines changed
  • broker/cloud_run/lsst/classify_snn

1 file changed

+41
-33
lines changed

broker/cloud_run/lsst/classify_snn/main.py

Lines changed: 41 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -128,43 +128,51 @@ def _classify(alert_lite: pittgoogle.Alert) -> dict:
128128

129129
def _format_for_classifier(alert_lite: pittgoogle.Alert) -> pd.DataFrame:
    """Create a DataFrame for input to SuperNNova.

    Builds the detection table from the alert, renames the survey's column
    names to the ones SuperNNova expects, and attaches the object id as SNID.
    """
    alert_lite_dict = alert_lite.dict["alert_lite"]

    # Table of detections already limited to the columns SuperNNova uses.
    snn_df = _create_dataframe(alert_lite_dict)

    # Map the survey's field names to SuperNNova's required column names.
    column_map = {
        "band": "FLT",
        "midpointMjdTai": "MJD",
        "psfFlux": "FLUXCAL",
        "psfFluxErr": "FLUXCALERR",
    }
    snn_df = snn_df.rename(columns=column_map)

    # Every row belongs to the same object, so broadcast its id into SNID,
    # then put the columns in the order SuperNNova expects.
    snn_df["SNID"] = alert_lite_dict["diaObject"]["diaObjectId"]
    return snn_df[["SNID", "FLT", "MJD", "FLUXCAL", "FLUXCALERR"]]
148147

149148

150-
def _create_dataframe(alert_dict: dict) -> "pd.DataFrame":
151-
"""Return a pandas DataFrame containing the source detections."""
149+
def _create_dataframe(alert_lite_dict: dict) -> pd.DataFrame:
    """Create a DataFrame object from the alert lite dictionary.

    Combines the triggering diaSource, any previous sources, and any previous
    forced sources into a single table containing only the columns SuperNNova
    needs downstream.
    """
    # Keep just these columns; dropping everything else before building the
    # DataFrames keeps the memory footprint low.
    wanted = ["band", "midpointMjdTai", "psfFlux", "psfFluxErr"]

    # Gather the raw detection records from the alert.
    detections = [alert_lite_dict.get("diaSource")]
    detections.extend(alert_lite_dict.get("prvDiaSources") or [])
    forced = alert_lite_dict.get("prvDiaForcedSources") or []

    # Filter each record down to the wanted columns, then stack sources and
    # forced sources into one DataFrame with a fresh index.
    frames = [
        pd.DataFrame(filter_columns(detections, wanted)),
        pd.DataFrame(filter_columns(forced, wanted)),
    ]
    return pd.concat(frames, ignore_index=True)
169+
170+
171+
def filter_columns(field_list: list, required_cols: list) -> list:
    """Extract only relevant columns if they exist.

    Args:
        field_list: Detection records (dicts), possibly containing None entries.
        required_cols: Keys to keep from each record.

    Returns:
        A list of dicts, one per non-None record, each restricted to the keys
        in ``required_cols`` that the record actually contains. Records that
        are None are skipped entirely.
    """
    # The `k in field` guard already guarantees the key exists, so index
    # directly instead of the redundant field.get(k).
    return [
        {k: field[k] for k in required_cols if k in field}
        for field in field_list
        if field is not None
    ]

0 commit comments

Comments
 (0)