Skip to content

Commit add0d50

Browse files
committed
update helper functions and optimize how DataFrames get created
1 parent a2584ba commit add0d50

File tree

1 file changed

+47
-40
lines changed
  • broker/cloud_run/lsst/variability

1 file changed

+47
-40
lines changed

broker/cloud_run/lsst/variability/main.py

Lines changed: 47 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python3
22
# -*- coding: UTF-8 -*-
33

4-
"""This module produces "value-added" lite alerts containing StetsonJ statistics on the DIA point source fluxes."""
4+
"""This module produces "value-added" lite alerts containing J indices on the DIA point source fluxes."""
55

66
import os
77
from typing import Dict
@@ -62,20 +62,20 @@ def run():
6262
except pittgoogle.exceptions.BadRequest as exc:
6363
return str(exc), HTTP_400
6464

65-
stetsonj_stats = _calculate_stetsonJ_statistics(alert_lite)
65+
j_dict = _classify(alert_lite)
6666
pg_variable = {"pg_variable": "unlikely"}
6767

6868
for band in ["g", "r", "u"]:
6969
if (
70-
stetsonj_stats.get(f"n_detections_{band}_band", 0) >= 30
71-
and stetsonj_stats.get(f"{band}_psfFluxStetsonJ", 0) > 20
70+
j_dict.get(f"n_detections_{band}_band", 0) >= 30
71+
and j_dict.get(f"{band}_psfFluxStetsonJ", 0) > 20
7272
):
7373
pg_variable = {"pg_variable": "likely"}
7474
break
7575

7676
TOPIC.publish(
7777
pittgoogle.Alert.from_dict(
78-
{"alert_lite": alert_lite.dict["alert_lite"], "variability": stetsonj_stats},
78+
{"alert_lite": alert_lite.dict["alert_lite"], "variability": j_dict},
7979
attributes={**alert_lite.attributes, **pg_variable},
8080
schema_name="default",
8181
)
@@ -84,14 +84,14 @@ def run():
8484
return "", HTTP_204
8585

8686

87-
def _calculate_stetsonJ_statistics(alert_lite: pittgoogle.Alert) -> Dict:
87+
def _classify(alert_lite: pittgoogle.Alert) -> Dict:
8888
"""Adapted from:
8989
https://github.com/lsst/meas_base/blob/e5cf12406b54a6312b9d6fa23fbd132cd7999387/python/lsst/meas/base/diaCalculationPlugins.py#L904
9090
9191
Compute the StetsonJ statistics on the DIA point source fluxes for each band.
9292
"""
93-
alert_lite_dict = alert_lite.dict["alert_lite"]
94-
alert_df = _create_dataframe(alert_lite_dict)
93+
94+
alert_df = _create_dataframe(alert_lite.dict["alert_lite"])
9595
bands = alert_df["band"].unique()
9696
outgoing_dict = {}
9797

@@ -118,32 +118,40 @@ def _calculate_stetsonJ_statistics(alert_lite: pittgoogle.Alert) -> Dict:
118118
return outgoing_dict
119119

120120

121-
def _create_dataframe(alert_dict: pittgoogle.Alert) -> "pd.DataFrame":
122-
"""Return a pandas DataFrame containing the source detections."""
121+
def _create_dataframe(alert_lite_dict: dict) -> pd.DataFrame:
122+
"""Create a DataFrame object from the alert lite dictionary."""
123123

124-
# sources and previous sources are expected to have the same fields
125-
sources_df = pd.DataFrame(
126-
[alert_dict.get("diaSource")] + (alert_dict.get("prvDiaSources") or [])
127-
)
128-
# sources and forced sources may have different fields
129-
forced_df = pd.DataFrame(alert_dict.get("prvDiaForcedSources") or [])
130-
131-
# use nullable integer data type to avoid converting ints to floats
132-
# for columns in one dataframe but not the other
133-
sources_ints = [c for c, v in sources_df.dtypes.items() if v == int]
134-
sources_df = sources_df.astype(
135-
{c: "Int64" for c in set(sources_ints) - set(forced_df.columns)}
136-
)
137-
forced_ints = [c for c, v in forced_df.dtypes.items() if v == int]
138-
forced_df = forced_df.astype({c: "Int64" for c in set(forced_ints) - set(sources_df.columns)})
124+
required_cols = [
125+
"band",
126+
"psfFlux",
127+
"psfFluxErr",
128+
] # columns required to compute the J index
129+
130+
# extract fields and create filtered DataFrames
131+
sources = [alert_lite_dict.get("diaSource")] + (alert_lite_dict.get("prvDiaSources") or [])
132+
forced_sources = alert_lite_dict.get("prvDiaForcedSources") or []
133+
sources_df = pd.DataFrame(filter_columns(sources, required_cols))
134+
forced_df = pd.DataFrame(filter_columns(forced_sources, required_cols))
135+
136+
# concatenate diaSource, prvDiaSources, and prvDiaForcedSources into a single DataFrame
137+
df = pd.concat([sources_df, forced_df], ignore_index=True)
139138

140-
_dataframe = pd.concat([sources_df, forced_df], ignore_index=True)
141-
return _dataframe
139+
return df
140+
141+
142+
def filter_columns(field_list, required_cols):
143+
"""Extract only relevant columns if they exist."""
144+
145+
return [
146+
{k: field.get(k) for k in required_cols if k in field}
147+
for field in field_list
148+
if field is not None
149+
]
142150

143151

144152
def _stetson_J(fluxes: np.ndarray, errors: np.ndarray) -> float:
145153
"""Adapted from:
146-
https://github.com/lsst/meas_base/blob/e5cf12406b54a6312b9d6fa23fbd132cd7999387/python/lsst/meas/base/diaCalculationPlugins.py#L949
154+
https://github.com/lsst/meas_base/blob/013ef565331c896a3fd73aefec294de42bc66371/python/lsst/meas/base/diaCalculationPlugins.py#L1279
147155
148156
Compute the single band StetsonJ statistic.
149157
@@ -172,21 +180,24 @@ def _stetson_J(fluxes: np.ndarray, errors: np.ndarray) -> float:
172180
return np.mean(np.sign(p_k) * np.sqrt(np.fabs(p_k)))
173181

174182

175-
def _stetson_mean(values: np.ndarray, errors: np.ndarray) -> float:
183+
def _stetson_mean(
184+
values: np.ndarray, errors: np.ndarray, mean=None, alpha=2.0, beta=2.0, n_iter=20, tol=1e-6
185+
) -> float:
176186
"""Adapted from:
177-
https://github.com/lsst/meas_base/blob/e5cf12406b54a6312b9d6fa23fbd132cd7999387/python/lsst/meas/base/diaCalculationPlugins.py#L979
187+
https://github.com/lsst/meas_base/blob/013ef565331c896a3fd73aefec294de42bc66371/python/lsst/meas/base/diaCalculationPlugins.py#L1309
178188
179-
Compute the Stetson mean of the fluxes which down-weights outliers.
180-
181-
Weighted biased on an error weighted difference scaled by a constant (1/``a``) and raised to the power beta. Higher
182-
betas more harshly penalize outliers and ``a`` sets the number of sigma where a weighted difference of 1 occurs.
189+
Compute the Stetson mean of the fluxes, which down-weights outliers. The weighting is based on an error-weighted difference
190+
scaled by a constant (1/``alpha``) and raised to the power beta. Higher betas more harshly penalize outliers and ``alpha``
191+
sets the number of sigma where a weighted difference of 1 occurs.
183192
184193
Parameters
185194
----------
186195
values : `numpy.ndarray`, (N,)
187196
Input values to compute the mean of.
188197
errors : `numpy.ndarray`, (N,)
189198
Errors on the input values.
199+
mean : `float`
200+
Starting mean value or None.
190201
alpha : `float`
191202
Scalar down-weighting of the fractional difference. lower->more clipping. (Default value is 2.)
192203
beta : `float`
@@ -206,16 +217,12 @@ def _stetson_mean(values: np.ndarray, errors: np.ndarray) -> float:
206217
.. [1] Stetson, P. B., "On the Automatic Determination of Light-Curve Parameters for Cepheid Variables", PASP, 108,
207218
851S, 1996
208219
"""
209-
# define values
210-
alpha = 2.0
211-
beta = 2.0
212-
n_iter = 20
213-
tol = 1e-6
214220

215221
n_points = len(values)
216222
n_factor = np.sqrt(n_points / (n_points - 1))
217223
inv_var = 1 / errors**2
218-
mean = np.average(values, weights=inv_var)
224+
if mean is None:
225+
mean = np.average(values, weights=inv_var)
219226

220227
for _ in range(n_iter):
221228
chi = np.fabs(n_factor * (values - mean) / errors)

0 commit comments

Comments
 (0)