Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring in support of externally defined feature extractor libraries - i.e. LAISS #73

Merged
merged 5 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 23 additions & 100 deletions src/resspect/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@
from resspect.query_strategies import *
from resspect.query_budget_strategies import *
from resspect.metrics import get_snpcc_metric
from resspect.plugin_utils import fetch_classifier_class, fetch_query_strategy_class
from resspect.plugin_utils import (
fetch_classifier_class,
fetch_feature_extractor_class,
fetch_query_strategy_class
)
from resspect.filter_sets import FILTER_SETS

__all__ = ['DataBase']

Expand Down Expand Up @@ -234,6 +239,10 @@ def load_features_from_file(self, path_to_features_file: str, screen=False,
"Bump", or "Malanchev". Default is "Bazin".
"""

if survey not in ['DES', 'LSST']:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AmandaWasserman Is this logic correct, or should we support any survey that we find in the FILTER_SETS dictionary? i.e. DES, LSST, and SNPCC

raise ValueError('Only "DES" and "LSST" filters are ' + \
'implemented at this point!')

# read matrix with features
if '.tar.gz' in path_to_features_file:
tar = tarfile.open(path_to_features_file, 'r:gz')
Expand All @@ -249,109 +258,23 @@ def load_features_from_file(self, path_to_features_file: str, screen=False,
if 'queryable' not in data.keys():
data['queryable'] = [True for i in range(data.shape[0])]

# Get the list of feature names from the feature extractor class
feature_extractor_class = fetch_feature_extractor_class(feature_extractor)

#! Make this part work better with the different feature extractors
# list of features to use
if survey == 'DES':
if feature_extractor == "Bazin":
self.features_names = ['gA', 'gB', 'gt0', 'gtfall', 'gtrise', 'rA',
'rB', 'rt0', 'rtfall', 'rtrise', 'iA', 'iB',
'it0', 'itfall', 'itrise', 'zA', 'zB', 'zt0',
'ztfall', 'ztrise']
elif feature_extractor == 'Bump':
self.features_names = ['gp1', 'gp2', 'gp3', 'gmax_flux',
'rp1', 'rp2', 'rp3', 'rmax_flux',
'ip1', 'ip2', 'ip3', 'imax_flux',
'zp1', 'zp2', 'zp3', 'zmax_flux']
elif feature_extractor == 'Malanchev':
self.features_names = ['ganderson_darling_normal','ginter_percentile_range_5',
'gchi2','gstetson_K','gweighted_mean','gduration',
'gotsu_mean_diff','gotsu_std_lower', 'gotsu_std_upper',
'gotsu_lower_to_all_ratio', 'glinear_fit_slope',
'glinear_fit_slope_sigma','glinear_fit_reduced_chi2',
'randerson_darling_normal', 'rinter_percentile_range_5',
'rchi2', 'rstetson_K', 'rweighted_mean','rduration',
'rotsu_mean_diff','rotsu_std_lower', 'rotsu_std_upper',
'rotsu_lower_to_all_ratio', 'rlinear_fit_slope',
'rlinear_fit_slope_sigma','rlinear_fit_reduced_chi2',
'ianderson_darling_normal','iinter_percentile_range_5',
'ichi2', 'istetson_K', 'iweighted_mean','iduration',
'iotsu_mean_diff','iotsu_std_lower', 'iotsu_std_upper',
'iotsu_lower_to_all_ratio', 'ilinear_fit_slope',
'ilinear_fit_slope_sigma','ilinear_fit_reduced_chi2',
'zanderson_darling_normal','zinter_percentile_range_5',
'zchi2', 'zstetson_K', 'zweighted_mean','zduration',
'zotsu_mean_diff','zotsu_std_lower', 'zotsu_std_upper',
'zotsu_lower_to_all_ratio', 'zlinear_fit_slope',
'zlinear_fit_slope_sigma','zlinear_fit_reduced_chi2']

self.metadata_names = ['id', 'redshift', 'type', 'code',
'orig_sample', 'queryable']

if 'last_rmag' in data.keys():
self.metadata_names.append('last_rmag')
# Create the filter-feature names based on the survey.
survey_filters = FILTER_SETS[survey]
self.features_names = feature_extractor_class.get_features(survey_filters)

for name in self.telescope_names:
if 'cost_' + name in data.keys():
self.metadata_names = self.metadata_names + ['cost_' + name]

elif survey == 'LSST':
if feature_extractor == "Bazin":
self.features_names = ['uA', 'uB', 'ut0', 'utfall', 'utrise',
'gA', 'gB', 'gt0', 'gtfall', 'gtrise',
'rA', 'rB', 'rt0', 'rtfall', 'rtrise',
'iA', 'iB', 'it0', 'itfall', 'itrise',
'zA', 'zB', 'zt0', 'ztfall', 'ztrise',
'YA', 'YB', 'Yt0', 'Ytfall', 'Ytrise']
elif feature_extractor == "Malanchev":
self.features_names = ['uanderson_darling_normal','uinter_percentile_range_5',
'uchi2','ustetson_K','uweighted_mean','uduration',
'uotsu_mean_diff','uotsu_std_lower', 'uotsu_std_upper',
'uotsu_lower_to_all_ratio', 'ulinear_fit_slope',
'ulinear_fit_slope_sigma','ulinear_fit_reduced_chi2',
'ganderson_darling_normal','ginter_percentile_range_5',
'gchi2','gstetson_K','gweighted_mean','gduration',
'gotsu_mean_diff','gotsu_std_lower', 'gotsu_std_upper',
'gotsu_lower_to_all_ratio', 'glinear_fit_slope',
'glinear_fit_slope_sigma','glinear_fit_reduced_chi2',
'randerson_darling_normal', 'rinter_percentile_range_5',
'rchi2', 'rstetson_K', 'rweighted_mean','rduration',
'rotsu_mean_diff','rotsu_std_lower', 'rotsu_std_upper',
'rotsu_lower_to_all_ratio', 'rlinear_fit_slope',
'rlinear_fit_slope_sigma','rlinear_fit_reduced_chi2',
'ianderson_darling_normal','iinter_percentile_range_5',
'ichi2', 'istetson_K', 'iweighted_mean','iduration',
'iotsu_mean_diff','iotsu_std_lower', 'iotsu_std_upper',
'iotsu_lower_to_all_ratio', 'ilinear_fit_slope',
'ilinear_fit_slope_sigma','ilinear_fit_reduced_chi2',
'zanderson_darling_normal','zinter_percentile_range_5',
'zchi2', 'zstetson_K', 'zweighted_mean','zduration',
'zotsu_mean_diff','zotsu_std_lower', 'zotsu_std_upper',
'zotsu_lower_to_all_ratio', 'zlinear_fit_slope',
'zlinear_fit_slope_sigma','zlinear_fit_reduced_chi2',
'Yanderson_darling_normal','Yinter_percentile_range_5',
'Ychi2', 'Ystetson_K', 'Yweighted_mean','Yduration',
'Yotsu_mean_diff','Yotsu_std_lower', 'Yotsu_std_upper',
'Yotsu_lower_to_all_ratio', 'Ylinear_fit_slope',
'Ylinear_fit_slope_sigma','Ylinear_fit_reduced_chi2']

if 'objid' in data.keys():
self.metadata_names = ['objid', 'redshift', 'type', 'code',
'orig_sample', 'queryable']
elif 'id' in data.keys():
self.metadata_names = ['id', 'redshift', 'type', 'code',
'orig_sample', 'queryable']

if 'last_rmag' in data.keys():
self.metadata_names.append('last_rmag')
self.metadata_names = ['id', 'redshift', 'type', 'code', 'orig_sample', 'queryable']
if 'objid' in data.keys():
self.metadata_names = ['objid', 'redshift', 'type', 'code', 'orig_sample', 'queryable']

for name in self.telescope_names:
if 'cost_' + name in data.keys():
self.metadata_names = self.metadata_names + ['cost_' + name]
if 'last_rmag' in data.keys():
self.metadata_names.append('last_rmag')

else:
raise ValueError('Only "DES" and "LSST" filters are ' + \
'implemented at this point!')
for name in self.telescope_names:
if 'cost_' + name in data.keys():
self.metadata_names = self.metadata_names + ['cost_' + name]

if sample == None:
self.features = data[self.features_names].values
Expand Down
12 changes: 7 additions & 5 deletions src/resspect/feature_extractors/bazin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@


class Bazin(LightCurve):
feature_names = ['A', 'B', 't0', 'tfall', 'trise']


    def __init__(self):
        """Initialize the Bazin extractor and cache the per-filter feature count."""
        super().__init__()
        # Number of Bazin fit parameters per filter (A, B, t0, tfall, trise).
        # NOTE(review): counts Bazin.feature_names explicitly rather than
        # type(self).feature_names — a subclass overriding feature_names would
        # still get Bazin's count; confirm this is intended.
        self.num_features = len(Bazin.feature_names)

def evaluate(self, time: np.array) -> dict:
"""
Expand Down Expand Up @@ -67,7 +70,7 @@ def fit(self, band: str) -> np.ndarray:

# build filter flag
band_indices = self.photometry['band'] == band
if not sum(band_indices) > (len(self.features_names) - 1):
if not sum(band_indices) > (self.num_features - 1):
return np.array([])

# get info for this filter
Expand All @@ -85,11 +88,10 @@ def fit_all(self):
Performs bazin fit for all filters independently and concatenate results.
Populates the attributes: bazinfeatures.
"""
default_bazinfeatures = ['None'] * len(self.features_names)
default_bazinfeatures = ['None'] * self.num_features

if self.photometry.shape[0] < 1:
self.features = ['None'] * len(
self.features_names) * len(self.filters)
self.features = ['None'] * self.num_features * len(self.filters)

elif 'None' not in self.features:
self.features = []
Expand Down
11 changes: 7 additions & 4 deletions src/resspect/feature_extractors/bump.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@


class Bump(LightCurve):
feature_names = ['p1', 'p2', 'p3', 'time_shift', 'max_flux']


    def __init__(self):
        """Initialize the Bump extractor and cache the per-filter feature count."""
        super().__init__()
        # Number of Bump fit parameters per filter
        # (p1, p2, p3, time_shift, max_flux).
        # NOTE(review): counts Bump.feature_names explicitly rather than
        # type(self).feature_names — a subclass overriding feature_names would
        # still get Bump's count; confirm this is intended.
        self.num_features = len(Bump.feature_names)

def evaluate(self, time: np.array) -> dict:
"""
Expand Down Expand Up @@ -64,7 +67,7 @@ def fit(self, band: str) -> np.ndarray:

# build filter flag
band_indices = self.photometry['band'] == band
if not sum(band_indices) > (len(self.features_names) - 2):
if not sum(band_indices) > (self.num_features - 2):
return np.array([])

# get info for this filter
Expand All @@ -81,10 +84,10 @@ def fit_all(self):
Perform Bump fit for all filters independently and concatenate results.
Populates the attributes: bump_features.
"""
default_bump_features = ['None'] * len(self.features_names)
default_bump_features = ['None'] * self.num_features

if self.photometry.shape[0] < 1:
self.features = ['None'] * len(self.features_names) * len(self.filters)
self.features = ['None'] * self.num_features * len(self.filters)

elif 'None' not in self.features:
self.features = []
Expand Down
73 changes: 73 additions & 0 deletions src/resspect/feature_extractors/feature_extractor_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import itertools
from typing import List

def make_features_header(
    filters: List[str],
    features: List[str],
    **kwargs
) -> list:
    """Return the full header list (metadata columns + filter-feature columns).

    The default metadata columns are
    ``['id', 'redshift', 'type', 'code', 'orig_sample']``; optional columns
    are appended according to the flags in ``kwargs``, and the per-filter
    feature columns always come last.

    Parameters
    ----------
    filters : list
        Filter values. e.g. ['g', 'r', 'i', 'z']
    features : list
        Feature values. e.g. ['A', 'B']
    kwargs
        Can include the following flags:
        - override_primary_columns: List[str] of primary columns to override
          the default ones
        - with_queryable: flag for adding "queryable" column
        - with_last_rmag: flag for adding "last_rmag" column
        - with_cost: flag for adding "cost_4m" and "cost_8m" columns

    Returns
    -------
    header
        header list
    """

    header = ['id', 'redshift', 'type', 'code', 'orig_sample']

    # There are rare instances where we need to override the primary columns.
    # Copy the caller's list so the appends below do not mutate it (assigning
    # it directly would leak 'queryable'/'last_rmag'/cost columns back into
    # the caller's object).
    override_columns = kwargs.get('override_primary_columns')
    if override_columns:
        header = list(override_columns)

    if kwargs.get('with_queryable', False):
        header.append('queryable')
    if kwargs.get('with_last_rmag', False):
        header.append('last_rmag')

    #TODO: find where the 'with_cost' flag is used to make sure we apply there
    if kwargs.get('with_cost', False):
        header.extend(['cost_4m', 'cost_8m'])

    # Create all pairs of filter + feature strings and append to the header
    header += create_filter_feature_names(filters, features)

    return header

def create_filter_feature_names(filters: List[str], features: List[str]) -> List[str]:
    """This function returns the list of concatenated filters and features. e.g.
    filter = ['g', 'r'], features = ['A', 'B'] => ['gA', 'gB', 'rA', 'rB']

    Parameters
    ----------
    filters : List[str]
        Filter name list
    features : List[str]
        Feature name list

    Returns
    -------
    List[str]
        List of filter-feature pairs.
    """

    # itertools.product iterates filters in the outer loop, so all features
    # for one filter are grouped together (gA, gB, rA, rB, ...).
    return [''.join(pair) for pair in itertools.product(filters, features)]
62 changes: 52 additions & 10 deletions src/resspect/feature_extractors/light_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,18 @@
import pandas as pd

from resspect.exposure_time_calculator import ExpTimeCalc
from resspect.lightcurves_utils import read_file
from resspect.lightcurves_utils import load_snpcc_photometry_df
from resspect.lightcurves_utils import get_photometry_with_id_name_and_snid
from resspect.lightcurves_utils import read_plasticc_full_photometry_data
from resspect.lightcurves_utils import load_plasticc_photometry_df
from resspect.lightcurves_utils import get_snpcc_sntype

from resspect.lightcurves_utils import (
get_photometry_with_id_name_and_snid,
get_snpcc_sntype,
load_plasticc_photometry_df,
load_snpcc_photometry_df,
read_file,
read_plasticc_full_photometry_data,
)
from resspect.feature_extractors.feature_extractor_utils import (
create_filter_feature_names,
make_features_header,
)

warnings.filterwarnings("ignore", category=RuntimeWarning)
logging.basicConfig(level=logging.INFO)
Expand All @@ -37,8 +42,8 @@ class LightCurve:

Attributes
----------
features_names: list
List of names of the feature extraction parameters.
feature_names: list
Class attribute, a list of names of the feature extraction parameters.
features: list
List with the 5 best-fit feature extraction parameters in all filters.
Concatenated from blue to red.
Expand Down Expand Up @@ -95,10 +100,11 @@ class LightCurve:

"""

feature_names = []

def __init__(self):
self.queryable = None
self.features = []
#self.features_names = ['p1', 'p2', 'p3', 'time_shift', 'max_flux']
self.dataset_name = ' '
self.exp_time = {}
self.filters = []
Expand Down Expand Up @@ -143,6 +149,42 @@ def fit_all(self):
"""
raise NotImplementedError()


@classmethod
def get_features(cls, filters: list) -> list[str]:
"""
Returns the header for the features extracted for all filters, excludes
non-feature columns (i.e. id, redshift, type, code, orig_sample, ...)

Parameters
----------
filters: list
List of broad band filters.

Returns
-------
list
"""
return create_filter_feature_names(filters, cls.feature_names)

@classmethod
def get_feature_header(cls, filters: list, **kwargs) -> list[str]:
"""
Returns the header for the features extracted.

Parameters
----------
filters: list
List of broad band filters.
kwargs: dict

Returns
-------
list
"""

return make_features_header(filters, cls.feature_names, **kwargs)

def _get_snpcc_photometry_raw_and_header(
self, lc_data: np.ndarray,
sntype_test_value: str = "-9") -> Tuple[np.ndarray, list]:
Expand Down
Loading
Loading