Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 41 additions & 33 deletions broker/cloud_run/lsst/classify_snn/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,43 +128,51 @@ def _classify(alert_lite: pittgoogle.Alert) -> dict:

def _format_for_classifier(alert_lite: pittgoogle.Alert) -> pd.DataFrame:
    """Create a DataFrame for input to SuperNNova.

    Parameters
    ----------
    alert_lite : pittgoogle.Alert
        Alert whose ``dict["alert_lite"]`` payload holds the DIA object and source records.

    Returns
    -------
    pd.DataFrame
        Columns ``SNID``, ``FLT``, ``MJD``, ``FLUXCAL``, ``FLUXCALERR`` as expected by SuperNNova.
    """
    alert_lite_dict = alert_lite.dict["alert_lite"]

    # create DataFrame from alert_lite_dict containing the columns SuperNNova expects
    filtered_df = _create_dataframe(alert_lite_dict)
    filtered_df = filtered_df.rename(
        columns={
            "band": "FLT",
            "midpointMjdTai": "MJD",
            "psfFlux": "FLUXCAL",
            "psfFluxErr": "FLUXCALERR",
        }
    )
    # every row belongs to the same DIA object, so a scalar broadcasts to all rows
    filtered_df["SNID"] = alert_lite_dict["diaObject"]["diaObjectId"]
    filtered_df = filtered_df[["SNID", "FLT", "MJD", "FLUXCAL", "FLUXCALERR"]]

    return filtered_df


def _create_dataframe(alert_lite_dict: dict) -> pd.DataFrame:
    """Create a DataFrame object from the alert lite dictionary.

    Concatenates the ``diaSource``, ``prvDiaSources``, and ``prvDiaForcedSources``
    records into a single DataFrame, keeping only the columns SuperNNova needs.

    Parameters
    ----------
    alert_lite_dict : dict
        Alert-lite payload with ``diaSource`` (dict) and optional
        ``prvDiaSources`` / ``prvDiaForcedSources`` (lists of dicts).

    Returns
    -------
    pd.DataFrame
        One row per detection, restricted to the required columns.
    """
    required_cols = [
        "band",
        "midpointMjdTai",
        "psfFlux",
        "psfFluxErr",
    ]  # columns required by SuperNNova

    # extract fields and create filtered DataFrames
    sources = [alert_lite_dict.get("diaSource")] + (alert_lite_dict.get("prvDiaSources") or [])
    forced_sources = alert_lite_dict.get("prvDiaForcedSources") or []
    sources_df = pd.DataFrame(filter_columns(sources, required_cols))
    forced_df = pd.DataFrame(filter_columns(forced_sources, required_cols))

    # concatenate diaSource, prvDiaSources, and prvDiaForcedSources into a single DataFrame
    df = pd.concat([sources_df, forced_df], ignore_index=True)

    return df


def filter_columns(field_list, required_cols):
    """Extract only relevant columns if they exist.

    Parameters
    ----------
    field_list : list
        Source records (dicts); ``None`` entries are dropped.
    required_cols : list of str
        Keys to keep when present in a record.

    Returns
    -------
    list of dict
        One dict per non-None record, restricted to ``required_cols``.
    """
    # NOTE: removed a stray, unreachable `return _dataframe` left over from a merge;
    # `_dataframe` is not defined in this scope and the line shadowed the real return.
    return [
        {k: field.get(k) for k in required_cols if k in field}
        for field in field_list
        if field is not None
    ]
87 changes: 47 additions & 40 deletions broker/cloud_run/lsst/variability/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
# -*- coding: UTF-8 -*-

"""This module produces "value-added" lite alerts containing StetsonJ statistics on the DIA point source fluxes."""
"""This module produces "value-added" lite alerts containing J indices on the DIA point source fluxes."""

import os
from typing import Dict
Expand Down Expand Up @@ -62,20 +62,20 @@
except pittgoogle.exceptions.BadRequest as exc:
return str(exc), HTTP_400

stetsonj_stats = _calculate_stetsonJ_statistics(alert_lite)
j_dict = _classify(alert_lite)
pg_variable = {"pg_variable": "unlikely"}

for band in ["g", "r", "u"]:
if (
stetsonj_stats.get(f"n_detections_{band}_band", 0) >= 30
and stetsonj_stats.get(f"{band}_psfFluxStetsonJ", 0) > 20
j_dict.get(f"n_detections_{band}_band", 0) >= 30
and j_dict.get(f"{band}_psfFluxStetsonJ", 0) > 20
):
pg_variable = {"pg_variable": "likely"}
break

TOPIC.publish(
pittgoogle.Alert.from_dict(
{"alert_lite": alert_lite.dict["alert_lite"], "variability": stetsonj_stats},
{"alert_lite": alert_lite.dict["alert_lite"], "variability": j_dict},
attributes={**alert_lite.attributes, **pg_variable},
schema_name="default",
)
Expand All @@ -84,14 +84,14 @@
return "", HTTP_204


def _calculate_stetsonJ_statistics(alert_lite: pittgoogle.Alert) -> Dict:
def _classify(alert_lite: pittgoogle.Alert) -> Dict:
"""Adapted from:
https://github.com/lsst/meas_base/blob/e5cf12406b54a6312b9d6fa23fbd132cd7999387/python/lsst/meas/base/diaCalculationPlugins.py#L904

Compute the StetsonJ statistics on the DIA point source fluxes for each band.
"""
alert_lite_dict = alert_lite.dict["alert_lite"]
alert_df = _create_dataframe(alert_lite_dict)

alert_df = _create_dataframe(alert_lite.dict["alert_lite"])
bands = alert_df["band"].unique()
outgoing_dict = {}

Expand All @@ -118,32 +118,40 @@
return outgoing_dict


def _create_dataframe(alert_lite_dict: dict) -> pd.DataFrame:
    """Create a DataFrame object from the alert lite dictionary.

    Concatenates the ``diaSource``, ``prvDiaSources``, and ``prvDiaForcedSources``
    records into a single DataFrame, keeping only the columns needed to compute
    the StetsonJ index.

    Parameters
    ----------
    alert_lite_dict : dict
        Alert-lite payload with ``diaSource`` (dict) and optional
        ``prvDiaSources`` / ``prvDiaForcedSources`` (lists of dicts).

    Returns
    -------
    pd.DataFrame
        One row per detection, restricted to the required columns.
    """
    required_cols = [
        "band",
        "psfFlux",
        "psfFluxErr",
    ]  # columns required to compute the J index

    # extract fields and create filtered DataFrames
    sources = [alert_lite_dict.get("diaSource")] + (alert_lite_dict.get("prvDiaSources") or [])
    forced_sources = alert_lite_dict.get("prvDiaForcedSources") or []
    sources_df = pd.DataFrame(filter_columns(sources, required_cols))
    forced_df = pd.DataFrame(filter_columns(forced_sources, required_cols))

    # concatenate diaSource, prvDiaSources, and prvDiaForcedSources into a single DataFrame
    df = pd.concat([sources_df, forced_df], ignore_index=True)

    return df


def filter_columns(field_list, required_cols):
    """Extract only relevant columns if they exist."""

    filtered = []
    for field in field_list:
        if field is None:
            # skip missing records entirely
            continue
        # keep only the keys that are both required and present
        filtered.append({key: field.get(key) for key in required_cols if key in field})
    return filtered


def _stetson_J(fluxes: np.ndarray, errors: np.ndarray) -> float:
"""Adapted from:
https://github.com/lsst/meas_base/blob/e5cf12406b54a6312b9d6fa23fbd132cd7999387/python/lsst/meas/base/diaCalculationPlugins.py#L949
https://github.com/lsst/meas_base/blob/013ef565331c896a3fd73aefec294de42bc66371/python/lsst/meas/base/diaCalculationPlugins.py#L1279

Compute the single band StetsonJ statistic.

Expand Down Expand Up @@ -172,21 +180,24 @@
return np.mean(np.sign(p_k) * np.sqrt(np.fabs(p_k)))


def _stetson_mean(values: np.ndarray, errors: np.ndarray) -> float:
def _stetson_mean(

Check warning on line 183 in broker/cloud_run/lsst/variability/main.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

broker/cloud_run/lsst/variability/main.py#L183

Too many positional arguments (7/5)
values: np.ndarray, errors: np.ndarray, mean=None, alpha=2.0, beta=2.0, n_iter=20, tol=1e-6
) -> float:
"""Adapted from:
https://github.com/lsst/meas_base/blob/e5cf12406b54a6312b9d6fa23fbd132cd7999387/python/lsst/meas/base/diaCalculationPlugins.py#L979
https://github.com/lsst/meas_base/blob/013ef565331c896a3fd73aefec294de42bc66371/python/lsst/meas/base/diaCalculationPlugins.py#L1309

Compute the Stetson mean of the fluxes which down-weights outliers.

Weighted biased on an error weighted difference scaled by a constant (1/``a``) and raised to the power beta. Higher
betas more harshly penalize outliers and ``a`` sets the number of sigma where a weighted difference of 1 occurs.
    Compute the Stetson mean of the fluxes, which down-weights outliers. Weighting is based on an error-weighted
    difference scaled by a constant (1/``a``) and raised to the power beta. Higher betas more harshly penalize outliers
    and ``a`` sets the number of sigma where a weighted difference of 1 occurs.

Parameters
----------
values : `numpy.dnarray`, (N,)
Input values to compute the mean of.
errors : `numpy.ndarray`, (N,)
Errors on the input values.
mean : `float`
Starting mean value or None.
alpha : `float`
Scalar down-weighting of the fractional difference. lower->more clipping. (Default value is 2.)
beta : `float`
Expand All @@ -206,16 +217,12 @@
.. [1] Stetson, P. B., "On the Automatic Determination of Light-Curve Parameters for Cepheid Variables", PASP, 108,
851S, 1996
"""
# define values
alpha = 2.0
beta = 2.0
n_iter = 20
tol = 1e-6

n_points = len(values)
n_factor = np.sqrt(n_points / (n_points - 1))
inv_var = 1 / errors**2
mean = np.average(values, weights=inv_var)
if mean is None:
mean = np.average(values, weights=inv_var)

for _ in range(n_iter):
chi = np.fabs(n_factor * (values - mean) / errors)
Expand Down