From 074863bd3f4de3c27133992a3d7f0e33ab642c1a Mon Sep 17 00:00:00 2001 From: Drew Oldag Date: Fri, 1 Nov 2024 16:43:21 -0700 Subject: [PATCH 1/5] Fairly substantial refactor to remove hardcoded column names from `lightcurve_utils`. Dynamically generate the column names as needed using methods defined in the `LightCurve` feature extractor subclasses. --- src/resspect/database.py | 88 ++-------- src/resspect/feature_extractors/bazin.py | 12 +- src/resspect/feature_extractors/bump.py | 11 +- .../feature_extractor_utils.py | 74 +++++++++ .../feature_extractors/light_curve.py | 41 +++-- src/resspect/feature_extractors/malanchev.py | 42 ++--- src/resspect/filter_sets.py | 7 + src/resspect/fit_lightcurves.py | 99 +++++------- src/resspect/lightcurves_utils.py | 150 ------------------ src/resspect/time_domain_plasticc.py | 25 +-- src/resspect/time_domain_snpcc.py | 28 ++-- .../resspect/test_feature_extractor_utils.py | 58 +++++++ 12 files changed, 284 insertions(+), 351 deletions(-) create mode 100644 src/resspect/feature_extractors/feature_extractor_utils.py create mode 100644 src/resspect/filter_sets.py create mode 100644 tests/resspect/test_feature_extractor_utils.py diff --git a/src/resspect/database.py b/src/resspect/database.py index b3ec7f78..f242883b 100644 --- a/src/resspect/database.py +++ b/src/resspect/database.py @@ -19,7 +19,13 @@ from resspect.query_strategies import * from resspect.query_budget_strategies import * from resspect.metrics import get_snpcc_metric -from resspect.plugin_utils import fetch_classifier_class, fetch_query_strategy_class +from resspect.plugin_utils import ( + fetch_classifier_class, + fetch_feature_extractor_class, + fetch_query_strategy_class +) +from resspect.filter_sets import FILTER_SETS +from resspect.feature_extractors.feature_extractor_utils import create_filter_feature_names __all__ = ['DataBase'] @@ -249,42 +255,15 @@ def load_features_from_file(self, path_to_features_file: str, screen=False, if 'queryable' not in data.keys(): data['queryable'] = [True for i in range(data.shape[0])] + # Get the list of feature names from the feature extractor class + feature_extractor_class = fetch_feature_extractor_class(feature_extractor) + feature_names = feature_extractor_class.feature_names - #! 
Make this part work better with the different feature extractors - # list of features to use - if survey == 'DES': - if feature_extractor == "Bazin": - self.features_names = ['gA', 'gB', 'gt0', 'gtfall', 'gtrise', 'rA', - 'rB', 'rt0', 'rtfall', 'rtrise', 'iA', 'iB', - 'it0', 'itfall', 'itrise', 'zA', 'zB', 'zt0', - 'ztfall', 'ztrise'] - elif feature_extractor == 'Bump': - self.features_names = ['gp1', 'gp2', 'gp3', 'gmax_flux', - 'rp1', 'rp2', 'rp3', 'rmax_flux', - 'ip1', 'ip2', 'ip3', 'imax_flux', - 'zp1', 'zp2', 'zp3', 'zmax_flux'] - elif feature_extractor == 'Malanchev': - self.features_names = ['ganderson_darling_normal','ginter_percentile_range_5', - 'gchi2','gstetson_K','gweighted_mean','gduration', - 'gotsu_mean_diff','gotsu_std_lower', 'gotsu_std_upper', - 'gotsu_lower_to_all_ratio', 'glinear_fit_slope', - 'glinear_fit_slope_sigma','glinear_fit_reduced_chi2', - 'randerson_darling_normal', 'rinter_percentile_range_5', - 'rchi2', 'rstetson_K', 'rweighted_mean','rduration', - 'rotsu_mean_diff','rotsu_std_lower', 'rotsu_std_upper', - 'rotsu_lower_to_all_ratio', 'rlinear_fit_slope', - 'rlinear_fit_slope_sigma','rlinear_fit_reduced_chi2', - 'ianderson_darling_normal','iinter_percentile_range_5', - 'ichi2', 'istetson_K', 'iweighted_mean','iduration', - 'iotsu_mean_diff','iotsu_std_lower', 'iotsu_std_upper', - 'iotsu_lower_to_all_ratio', 'ilinear_fit_slope', - 'ilinear_fit_slope_sigma','ilinear_fit_reduced_chi2', - 'zanderson_darling_normal','zinter_percentile_range_5', - 'zchi2', 'zstetson_K', 'zweighted_mean','zduration', - 'zotsu_mean_diff','zotsu_std_lower', 'zotsu_std_upper', - 'zotsu_lower_to_all_ratio', 'zlinear_fit_slope', - 'zlinear_fit_slope_sigma','zlinear_fit_reduced_chi2'] + # Create the filter-feature names based on the survey. + survey_filters = FILTER_SETS[survey] + self.features_names = create_filter_feature_names(survey_filters, feature_names) + if survey == 'DES': self.metadata_names = ['id', 'redshift', 'type', 'code', 'orig_sample', 'queryable'] @@ -296,45 +275,6 @@ def load_features_from_file(self, path_to_features_file: str, screen=False, self.metadata_names = self.metadata_names + ['cost_' + name] elif survey == 'LSST': - if feature_extractor == "Bazin": - self.features_names = ['uA', 'uB', 'ut0', 'utfall', 'utrise', - 'gA', 'gB', 'gt0', 'gtfall', 'gtrise', - 'rA', 'rB', 'rt0', 'rtfall', 'rtrise', - 'iA', 'iB', 'it0', 'itfall', 'itrise', - 'zA', 'zB', 'zt0', 'ztfall', 'ztrise', - 'YA', 'YB', 'Yt0', 'Ytfall', 'Ytrise'] - elif feature_extractor == "Malanchev": - self.features_names = ['uanderson_darling_normal','uinter_percentile_range_5', - 'uchi2','ustetson_K','uweighted_mean','uduration', - 'uotsu_mean_diff','uotsu_std_lower', 'uotsu_std_upper', - 'uotsu_lower_to_all_ratio', 'ulinear_fit_slope', - 'ulinear_fit_slope_sigma','ulinear_fit_reduced_chi2', - 'ganderson_darling_normal','ginter_percentile_range_5', - 'gchi2','gstetson_K','gweighted_mean','gduration', - 'gotsu_mean_diff','gotsu_std_lower', 'gotsu_std_upper', - 'gotsu_lower_to_all_ratio', 'glinear_fit_slope', - 'glinear_fit_slope_sigma','glinear_fit_reduced_chi2', - 'randerson_darling_normal', 'rinter_percentile_range_5', - 'rchi2', 'rstetson_K', 'rweighted_mean','rduration', - 'rotsu_mean_diff','rotsu_std_lower', 'rotsu_std_upper', - 'rotsu_lower_to_all_ratio', 'rlinear_fit_slope', - 'rlinear_fit_slope_sigma','rlinear_fit_reduced_chi2', - 'ianderson_darling_normal','iinter_percentile_range_5', - 'ichi2', 'istetson_K', 'iweighted_mean','iduration', - 'iotsu_mean_diff','iotsu_std_lower', 'iotsu_std_upper', - 
'iotsu_lower_to_all_ratio', 'ilinear_fit_slope', - 'ilinear_fit_slope_sigma','ilinear_fit_reduced_chi2', - 'zanderson_darling_normal','zinter_percentile_range_5', - 'zchi2', 'zstetson_K', 'zweighted_mean','zduration', - 'zotsu_mean_diff','zotsu_std_lower', 'zotsu_std_upper', - 'zotsu_lower_to_all_ratio', 'zlinear_fit_slope', - 'zlinear_fit_slope_sigma','zlinear_fit_reduced_chi2', - 'Yanderson_darling_normal','Yinter_percentile_range_5', - 'Ychi2', 'Ystetson_K', 'Yweighted_mean','Yduration', - 'Yotsu_mean_diff','Yotsu_std_lower', 'Yotsu_std_upper', - 'Yotsu_lower_to_all_ratio', 'Ylinear_fit_slope', - 'Ylinear_fit_slope_sigma','Ylinear_fit_reduced_chi2'] - if 'objid' in data.keys(): self.metadata_names = ['objid', 'redshift', 'type', 'code', 'orig_sample', 'queryable'] diff --git a/src/resspect/feature_extractors/bazin.py b/src/resspect/feature_extractors/bazin.py index d1fb0e66..a3296f72 100644 --- a/src/resspect/feature_extractors/bazin.py +++ b/src/resspect/feature_extractors/bazin.py @@ -11,9 +11,12 @@ class Bazin(LightCurve): + feature_names = ['A', 'B', 't0', 'tfall', 'trise'] + + def __init__(self): super().__init__() - self.features_names = ['A', 'B', 't0', 'tfall', 'trise'] + self.num_features = len(Bazin.feature_names) def evaluate(self, time: np.array) -> dict: """ @@ -67,7 +70,7 @@ def fit(self, band: str) -> np.ndarray: # build filter flag band_indices = self.photometry['band'] == band - if not sum(band_indices) > (len(self.features_names) - 1): + if not sum(band_indices) > (self.num_features - 1): return np.array([]) # get info for this filter @@ -85,11 +88,10 @@ def fit_all(self): Performs bazin fit for all filters independently and concatenate results. Populates the attributes: bazinfeatures. """ - default_bazinfeatures = ['None'] * len(self.features_names) + default_bazinfeatures = ['None'] * self.num_features if self.photometry.shape[0] < 1: - self.features = ['None'] * len( - self.features_names) * len(self.filters) + self.features = ['None'] * self.num_features * len(self.filters) elif 'None' not in self.features: self.features = [] diff --git a/src/resspect/feature_extractors/bump.py b/src/resspect/feature_extractors/bump.py index 51811d5b..c7a680ec 100644 --- a/src/resspect/feature_extractors/bump.py +++ b/src/resspect/feature_extractors/bump.py @@ -11,9 +11,12 @@ class Bump(LightCurve): + feature_names = ['p1', 'p2', 'p3', 'time_shift', 'max_flux'] + + def __init__(self): super().__init__() - self.features_names = ['p1', 'p2', 'p3', 'time_shift', 'max_flux'] + self.num_features = len(Bump.feature_names) def evaluate(self, time: np.array) -> dict: """ @@ -64,7 +67,7 @@ def fit(self, band: str) -> np.ndarray: # build filter flag band_indices = self.photometry['band'] == band - if not sum(band_indices) > (len(self.features_names) - 2): + if not sum(band_indices) > (self.num_features - 2): return np.array([]) # get info for this filter @@ -81,10 +84,10 @@ def fit_all(self): Perform Bump fit for all filters independently and concatenate results. Populates the attributes: bump_features. 
""" - default_bump_features = ['None'] * len(self.features_names) + default_bump_features = ['None'] * self.num_features if self.photometry.shape[0] < 1: - self.features = ['None'] * len(self.features_names) * len(self.filters) + self.features = ['None'] * self.num_features * len(self.filters) elif 'None' not in self.features: self.features = [] diff --git a/src/resspect/feature_extractors/feature_extractor_utils.py b/src/resspect/feature_extractors/feature_extractor_utils.py new file mode 100644 index 00000000..3b3f07d2 --- /dev/null +++ b/src/resspect/feature_extractors/feature_extractor_utils.py @@ -0,0 +1,74 @@ +import itertools +from typing import List + +def make_features_header( + filters: List[str], + features: List[str], + **kwargs + ) -> list: + """ + This function returns header list for given filters and features. The default + header names are: ['id', 'redshift', 'type', 'code', 'orig_sample']. + + Parameters + ---------- + filters : list + Filter values. e.g. ['g', 'r', 'i', 'z'] + features : list + Feature values. e.g. ['A', 'B'] + with_cost : bool + Flag for adding cost values. Default is False + kwargs + Can include the following flags: + - override_primary_columns: List[str] of primary columns to override the default ones + - with_queryable: flag for adding "queryable" column + - with_last_rmag: flag for adding "last_rmag" column + - with_cost: flag for adding "cost_4m" and "cost_8m" columns + + Returns + ------- + header + header list + """ + + header = [] + header.extend(['id', 'redshift', 'type', 'code', 'orig_sample']) + + # There are rare instances where we need to override the primary columns + if kwargs.get('override_primary_columns', False): + header = kwargs.get('override_primary_columns') + + + if kwargs.get('with_queryable', False): + header.append('queryable') + if kwargs.get('with_last_rmag', False): + header.append('last_rmag') + + #TODO: find where the 'with_cost' flag is used to make sure we apply there + if kwargs.get('with_cost', False): + header.extend(['cost_4m', 'cost_8m']) + + # Create all pairs of filter + feature strings and append to the header + filter_features = create_filter_feature_names(filters, features) + header += filter_features + + return header + +def create_filter_feature_names(filters: List[str], features: List[str]) -> List[str]: + """This function returns the list of concatenated filters and features. e.g. + filter = ['g', 'r'], features = ['A', 'B'] => ['gA', 'gB', 'rA', 'rB'] + + Parameters + ---------- + filters : List[str] + Filter name list + features : List[str] + Feature name list + + Returns + ------- + List[str] + List of filter-feature pairs. 
+ """ + + return [''.join(pair) for pair in itertools.product(filters, features)] \ No newline at end of file diff --git a/src/resspect/feature_extractors/light_curve.py b/src/resspect/feature_extractors/light_curve.py index 076ae6de..22d1a52c 100644 --- a/src/resspect/feature_extractors/light_curve.py +++ b/src/resspect/feature_extractors/light_curve.py @@ -17,13 +17,15 @@ import pandas as pd from resspect.exposure_time_calculator import ExpTimeCalc -from resspect.lightcurves_utils import read_file -from resspect.lightcurves_utils import load_snpcc_photometry_df -from resspect.lightcurves_utils import get_photometry_with_id_name_and_snid -from resspect.lightcurves_utils import read_plasticc_full_photometry_data -from resspect.lightcurves_utils import load_plasticc_photometry_df -from resspect.lightcurves_utils import get_snpcc_sntype - +from resspect.lightcurves_utils import ( + get_photometry_with_id_name_and_snid, + get_snpcc_sntype, + load_plasticc_photometry_df, + load_snpcc_photometry_df, + read_file, + read_plasticc_full_photometry_data, +) +from resspect.feature_extractors.feature_extractor_utils import make_features_header warnings.filterwarnings("ignore", category=RuntimeWarning) logging.basicConfig(level=logging.INFO) @@ -37,8 +39,8 @@ class LightCurve: Attributes ---------- - features_names: list - List of names of the feature extraction parameters. + feature_names: list + Class attribute, a list of names of the feature extraction parameters. features: list List with the 5 best-fit feature extraction parameters in all filters. Concatenated from blue to red. @@ -95,10 +97,11 @@ class LightCurve: """ + feature_names = [] + def __init__(self): self.queryable = None self.features = [] - #self.features_names = ['p1', 'p2', 'p3', 'time_shift', 'max_flux'] self.dataset_name = ' ' self.exp_time = {} self.filters = [] @@ -143,6 +146,24 @@ def fit_all(self): """ raise NotImplementedError() + @classmethod + def get_feature_header(cls, filters: list, **kwargs) -> list[str]: + """ + Returns the header for the features extracted. + + Parameters + ---------- + filters: list + List of broad band filters. + kwargs: dict + + Returns + ------- + list + """ + + return make_features_header(filters, cls.feature_names, **kwargs) + def _get_snpcc_photometry_raw_and_header( self, lc_data: np.ndarray, sntype_test_value: str = "-9") -> Tuple[np.ndarray, list]: diff --git a/src/resspect/feature_extractors/malanchev.py b/src/resspect/feature_extractors/malanchev.py index d294e84c..ee27c771 100644 --- a/src/resspect/feature_extractors/malanchev.py +++ b/src/resspect/feature_extractors/malanchev.py @@ -17,23 +17,28 @@ class Malanchev(LightCurve): + feature_names = [ + 'anderson_darling_normal', + 'inter_percentile_range_5', + 'chi2', + 'stetson_K', + 'weighted_mean', + 'duration', + 'otsu_mean_diff', + 'otsu_std_lower', + 'otsu_std_upper', + 'otsu_lower_to_all_ratio', + 'linear_fit_slope', + 'linear_fit_slope_sigma', + 'linear_fit_reduced_chi2' + ] + + def __init__(self): super().__init__() - self.features_names = ['anderson_darling_normal', - 'inter_percentile_range_5', - 'chi2', - 'stetson_K', - 'weighted_mean', - 'duration', - 'otsu_mean_diff', - 'otsu_std_lower', - 'otsu_std_upper', - 'otsu_lower_to_all_ratio', - 'linear_fit_slope', - 'linear_fit_slope_sigma', - 'linear_fit_reduced_chi2'] + self.num_features = len(Malanchev.feature_names) + - def fit(self, band: str) -> np.ndarray: """ Extracts malanchev-light-curve features for one filter. 
@@ -90,7 +95,6 @@ def fit(self, band: str) -> np.ndarray: check=False) - def fit_all_points(self): """ Extracts Malanchev's light_curve features for all data points in all filters together. @@ -142,16 +146,16 @@ def fit_all_points(self): sorted = True, check = False) + def fit_all(self): """ Performs malanchev-light-curve feature extraction for all filters independently and concatenate results. Populates the attributes: mlcfeatures. """ - default_mlcfeatures = ['None'] * len(self.features_names) + default_mlcfeatures = ['None'] * self.num_features if self.photometry.shape[0] < 1: - self.features = ['None'] * len( - self.features_names) * len(self.filters) + self.features = ['None'] * self.num_features * len(self.filters) elif 'None' not in self.features: self.features = [] @@ -162,4 +166,4 @@ def fit_all(self): else: self.features.extend(default_mlcfeatures) else: - self.features.extend(default_mlcfeatures) \ No newline at end of file + self.features.extend(default_mlcfeatures) diff --git a/src/resspect/filter_sets.py b/src/resspect/filter_sets.py new file mode 100644 index 00000000..11f0dd0b --- /dev/null +++ b/src/resspect/filter_sets.py @@ -0,0 +1,7 @@ +"""These are the lists of survey filters.""" + +FILTER_SETS = { + "SNPCC": ["g", "r", "i", "z"], + "DES": ["g", "r", "i", "z"], + "LSST": ["u", "g", "r", "i", "z", "Y"], +} diff --git a/src/resspect/fit_lightcurves.py b/src/resspect/fit_lightcurves.py index fe43bda7..2794cfcc 100644 --- a/src/resspect/fit_lightcurves.py +++ b/src/resspect/fit_lightcurves.py @@ -14,22 +14,18 @@ import os from copy import copy from itertools import repeat -from typing import IO +from typing import IO, List, Union import numpy as np import pandas as pd -from resspect.lightcurves_utils import get_resspect_header_data -from resspect.lightcurves_utils import read_plasticc_full_photometry_data -from resspect.lightcurves_utils import SNPCC_FEATURES_HEADER -from resspect.lightcurves_utils import TOM_FEATURES_HEADER -from resspect.lightcurves_utils import TOM_MALANCHEV_FEATURES_HEADER -from resspect.lightcurves_utils import SNPCC_MALANCHEV_FEATURES_HEADER -from resspect.lightcurves_utils import find_available_key_name_in_header -from resspect.lightcurves_utils import PLASTICC_TARGET_TYPES -from resspect.lightcurves_utils import PLASTICC_RESSPECT_FEATURES_HEADER -from resspect.lightcurves_utils import BUMP_HEADERS -from resspect.lightcurves_utils import make_features_header +from resspect.lightcurves_utils import ( + read_plasticc_full_photometry_data, + find_available_key_name_in_header, + PLASTICC_TARGET_TYPES, + PLASTICC_RESSPECT_FEATURES_HEADER, +) +from resspect.filter_sets import FILTER_SETS from resspect.plugin_utils import fetch_feature_extractor_class from resspect.tom_client import TomClient @@ -119,12 +115,8 @@ def fit_snpcc( feature_extractor: str, default Bazin Function used for feature extraction. 
""" - if feature_extractor == 'Bazin': - header = SNPCC_FEATURES_HEADER - elif feature_extractor == 'Malanchev': - header = SNPCC_MALANCHEV_FEATURES_HEADER - elif feature_extractor == 'Bump': - header = BUMP_HEADERS["snpcc_header"] + feature_extractor_class = fetch_feature_extractor_class(feature_extractor) + header = feature_extractor_class.get_feature_header(filters=FILTER_SETS['SNPCC']) files_list = os.listdir(path_to_data_dir) files_list = [each_file for each_file in files_list @@ -262,7 +254,7 @@ def _TOM_sample_fit( def fit_TOM(data_dic: dict, output_features_file: str, number_of_processors: int = MAX_NUMBER_OF_PROCESSES, - feature_extractor: str = 'bazin'): + feature_extractor: str = 'Bazin'): """ Perform fit to all objects from the TOM data. @@ -274,13 +266,12 @@ def fit_TOM(data_dic: dict, output_features_file: str, Path to output file where results should be stored. number_of_processors: int, default 1 Number of cpu processes to use. - feature_extractor: str, default bazin + feature_extractor: str, default Bazin Function used for feature extraction. """ - if feature_extractor == 'bazin': - header = TOM_FEATURES_HEADER - elif feature_extractor == 'malanchev': - header = TOM_MALANCHEV_FEATURES_HEADER + + feature_extractor_class = fetch_feature_extractor_class(feature_extractor) + header = feature_extractor_class.get_feature_header(filters=FILTER_SETS['LSST']) multi_process = multiprocessing.Pool(number_of_processors) logging.info("Starting TOM " + feature_extractor + " fit...") @@ -328,11 +319,16 @@ def _sample_fit( return light_curve_data -def fit(data_dic: dict, output_features_file: str, - number_of_processors: int = MAX_NUMBER_OF_PROCESSES, - feature_extractor: str = 'bazin', filters: list = ['SNPCC'], - features: list = [], type: str = 'unspecified', one_code: list = [10], - additional_info: list = []): +def fit( + data_dic: dict, + output_features_file: str, + number_of_processors: int = MAX_NUMBER_OF_PROCESSES, + feature_extractor: str = 'Bazin', + filters: Union[str, List[str]] = 'SNPCC', + type: str = 'unspecified', + one_code: list = [10], + additional_info: list = [] + ): """ Perform fit to all objects from a generalized dataset. @@ -345,13 +341,10 @@ def fit(data_dic: dict, output_features_file: str, Path to output file where results should be stored. number_of_processors: int, default 1 Number of cpu processes to use. - feature_extractor: str, default bazin + feature_extractor: str, default Bazin Function used for feature extraction. filters: list - List of filters to be used. Or SNPCC/LSST. - features: list - List of features to be used. Or default features with respect to - the feature extractor. + List of filters to be used, or a key in FILTER_SETS. type: str Type of data: train, test, validation, pool one_code: list @@ -367,7 +360,7 @@ def fit(data_dic: dict, output_features_file: str, 'objectid' object id 'photometry' - dictionary containing keys ''mjd', 'band', 'flux', 'fluxerr'. + dictionary containing keys ''mjd', 'band', 'flux', 'fluxerr'. 
each entry contains a list of mjd, band, flux, fluxerr for each observation 'redshift' @@ -381,34 +374,14 @@ def fit(data_dic: dict, output_features_file: str, 'RA' 'dec' """ - if feature_extractor == 'bazin': - header = TOM_FEATURES_HEADER - elif feature_extractor == 'malanchev': - header = TOM_MALANCHEV_FEATURES_HEADER - - if 'SNPCC' in filters: - filters = ['g', 'r', 'i', 'z'] - elif 'LSST' in filters: - filters = ['u', 'g', 'r', 'i', 'z', 'Y'] - - if feature_extractor == 'bazin': - features = ['A', 'B', 't0', 'tfall', 'trise'] - elif feature_extractor == 'malanchev': - features = ['anderson_darling_normal', - 'inter_percentile_range_5', - 'chi2', - 'stetson_K', - 'weighted_mean', - 'duration', - 'otsu_mean_diff', - 'otsu_std_lower', - 'otsu_std_upper', - 'otsu_lower_to_all_ratio', - 'linear_fit_slope', - 'linear_fit_slope_sigma', - 'linear_fit_reduced_chi2'] - - header = make_features_header(filters, features) + + # if `filters` is a key in FILTER_SETS, then use the corresponding value + # otherwise, assume `filters` is a list of filter strings like `['g', 'r']`. + if isinstance(filters, str) and filters in FILTER_SETS: + filters = FILTER_SETS[filters] + + feature_extractor_class = fetch_feature_extractor_class(feature_extractor) + header = feature_extractor_class.get_feature_header(filters) multi_process = multiprocessing.Pool(number_of_processors) if feature_extractor != None: diff --git a/src/resspect/lightcurves_utils.py b/src/resspect/lightcurves_utils.py index 119645ea..83ec13aa 100644 --- a/src/resspect/lightcurves_utils.py +++ b/src/resspect/lightcurves_utils.py @@ -14,40 +14,6 @@ from resspect.snana_fits_to_pd import read_fits - -BAZIN_HEADERS = { - 'plasticc_header': [ - 'id', 'redshift', 'type', 'code', 'sample', 'queryable', 'last_rmag', - 'uA', 'uB', 'ut0', 'utfall', 'utrise', 'gA', 'gB', 'gt0', 'gtfall', - 'gtrise', 'rA', 'rB', 'rt0', 'rtfall', 'rtrise', 'iA', 'iB', 'it0', - 'itfall', 'itrise', 'zA', 'zB', 'zt0', 'ztfall', 'ztrise', 'YA', 'YB', - 'Yt0', 'Ytfall', 'Ytrise'], - 'plasticc_header_with_cost': [ - 'id', 'redshift', 'type', 'code', 'sample', 'queryable', 'last_rmag', - 'cost_4m', 'cost_8m', 'uA', 'uB', 'ut0', 'utfall', 'utrise', 'gA', 'gB', - 'gt0', 'gtfall', 'gtrise', 'rA', 'rB', 'rt0', 'rtfall', 'rtrise', 'iA', - 'iB', 'it0', 'itfall', 'itrise', 'zA', 'zB', 'zt0', 'ztfall', 'ztrise', - 'YA', 'YB', 'Yt0', 'Ytfall', 'Ytrise'], - 'snpcc_header': [ - 'id', 'redshift', 'type', 'code', 'orig_sample', 'queryable', - 'last_rmag', 'gA', 'gB', 'gt0', 'gtfall', 'gtrise', 'rA', 'rB', - 'rt0', 'rtfall', 'rtrise', 'iA', 'iB', 'it0', 'itfall', 'itrise', - 'zA', 'zB', 'zt0', 'ztfall', 'ztrise'], - 'snpcc_header_with_cost': [ - 'id', 'redshift', 'type', 'code', 'orig_sample', 'queryable', - 'last_rmag', 'cost_4m', 'cost_8m', 'gA', 'gB', 'gt0', 'gtfall', - 'gtrise', 'rA', 'rB', 'rt0', 'rtfall', 'rtrise', 'iA', 'iB', 'it0', - 'itfall', 'itrise', 'zA', 'zB', 'zt0', 'ztfall', 'ztrise'] -} - -BUMP_HEADERS = { - 'snpcc_header': ['gp1', 'gp2', 'gp3', 'gtime_shift', 'gmax_flux', - 'rp1', 'rp2', 'rp3', 'rtime_shift', 'rmax_flux', - 'ip1', 'ip2', 'ip3', 'itime_shift', 'imax_flux', - 'zp1', 'zp2', 'zp3', 'ztime_shift', 'zmax_flux'] -} - - SNPCC_LC_MAPPINGS = { "snii": {2, 3, 4, 12, 15, 17, 19, 20, 21, 24, 25, 26, 27, 30, 31, 32, 33, 34, 35, 36, 37, 38, @@ -56,60 +22,6 @@ 18, 22, 23, 29, 45, 28} } -SNPCC_FEATURES_HEADER = [ - 'id', 'redshift', 'type', 'code', 'orig_sample', - 'gA', 'gB', 'gt0', 'gtfall', 'gtrise', 'rA', 'rB', - 'rt0', 'rtfall', 'rtrise', 'iA', 'iB', 'it0', 'itfall', - 
'itrise', 'zA', 'zB', 'zt0', 'ztfall', 'ztrise' -] - -SNPCC_MALANCHEV_FEATURES_HEADER = [ - 'id', 'redshift', 'type', 'code', 'orig_sample', - 'ganderson_darling_normal','ginter_percentile_range_5', - 'gchi2','gstetson_K','gweighted_mean','gduration', 'gotsu_mean_diff','gotsu_std_lower', 'gotsu_std_upper', - 'gotsu_lower_to_all_ratio', 'glinear_fit_slope', 'glinear_fit_slope_sigma','glinear_fit_reduced_chi2', - 'randerson_darling_normal', 'rinter_percentile_range_5', - 'rchi2', 'rstetson_K', 'rweighted_mean','rduration', 'rotsu_mean_diff','rotsu_std_lower', 'rotsu_std_upper', - 'rotsu_lower_to_all_ratio', 'rlinear_fit_slope', 'rlinear_fit_slope_sigma','rlinear_fit_reduced_chi2', - 'ianderson_darling_normal','iinter_percentile_range_5', - 'ichi2', 'istetson_K', 'iweighted_mean','iduration', 'iotsu_mean_diff','iotsu_std_lower', 'iotsu_std_upper', - 'iotsu_lower_to_all_ratio', 'ilinear_fit_slope', 'ilinear_fit_slope_sigma','ilinear_fit_reduced_chi2', - 'zanderson_darling_normal','zinter_percentile_range_5', - 'zchi2', 'zstetson_K', 'zweighted_mean','zduration', 'zotsu_mean_diff','zotsu_std_lower', 'zotsu_std_upper', - 'zotsu_lower_to_all_ratio', 'zlinear_fit_slope', 'zlinear_fit_slope_sigma','zlinear_fit_reduced_chi2' -] - -TOM_FEATURES_HEADER = [ - 'id', 'redshift', 'type', 'code', 'orig_sample', - 'uA', 'uB', 'ut0', 'utfall', 'utrise', - 'gA', 'gB', 'gt0', 'gtfall', 'gtrise', 'rA', 'rB', - 'rt0', 'rtfall', 'rtrise', 'iA', 'iB', 'it0', 'itfall', - 'itrise', 'zA', 'zB', 'zt0', 'ztfall', 'ztrise', - 'YA', 'YB', 'Yt0', 'Ytfall', 'Ytrise' -] - -TOM_MALANCHEV_FEATURES_HEADER = [ - 'id', 'redshift', 'type', 'code', 'orig_sample', - 'uanderson_darling_normal','uinter_percentile_range_5', - 'uchi2','ustetson_K','uweighted_mean','uduration', 'uotsu_mean_diff','uotsu_std_lower', 'uotsu_std_upper', - 'uotsu_lower_to_all_ratio', 'ulinear_fit_slope', 'ulinear_fit_slope_sigma','ulinear_fit_reduced_chi2', - 'ganderson_darling_normal','ginter_percentile_range_5', - 'gchi2','gstetson_K','gweighted_mean','gduration', 'gotsu_mean_diff','gotsu_std_lower', 'gotsu_std_upper', - 'gotsu_lower_to_all_ratio', 'glinear_fit_slope', 'glinear_fit_slope_sigma','glinear_fit_reduced_chi2', - 'randerson_darling_normal', 'rinter_percentile_range_5', - 'rchi2', 'rstetson_K', 'rweighted_mean','rduration', 'rotsu_mean_diff','rotsu_std_lower', 'rotsu_std_upper', - 'rotsu_lower_to_all_ratio', 'rlinear_fit_slope', 'rlinear_fit_slope_sigma','rlinear_fit_reduced_chi2', - 'ianderson_darling_normal','iinter_percentile_range_5', - 'ichi2', 'istetson_K', 'iweighted_mean','iduration', 'iotsu_mean_diff','iotsu_std_lower', 'iotsu_std_upper', - 'iotsu_lower_to_all_ratio', 'ilinear_fit_slope', 'ilinear_fit_slope_sigma','ilinear_fit_reduced_chi2', - 'zanderson_darling_normal','zinter_percentile_range_5', - 'zchi2', 'zstetson_K', 'zweighted_mean','zduration', 'zotsu_mean_diff','zotsu_std_lower', 'zotsu_std_upper', - 'zotsu_lower_to_all_ratio', 'zlinear_fit_slope', 'zlinear_fit_slope_sigma','zlinear_fit_reduced_chi2', - 'Yanderson_darling_normal','Yinter_percentile_range_5', - 'Ychi2','Ystetson_K','Yweighted_mean','Yduration', 'Yotsu_mean_diff','Yotsu_std_lower', 'Yotsu_std_upper', - 'Yotsu_lower_to_all_ratio', 'Ylinear_fit_slope', 'Ylinear_fit_slope_sigma','Ylinear_fit_reduced_chi2' -] - PLASTICC_RESSPECT_FEATURES_HEADER = [ 'id', 'redshift', 'type', 'code', 'orig_sample', 'uA', 'uB', 'ut0', 'utfall', 'utrise', 'gA', 'gB', 'gt0', 'gtfall','gtrise', 'rA', 'rB', @@ -131,68 +43,6 @@ SNPCC_CANONICAL_FEATURES = ['z', 'g_pkmag', 'r_pkmag', 'i_pkmag', 
'z_pkmag', 'g_SNR', 'r_SNR', 'i_SNR', 'z_SNR'] -MALANCHEV_HEADERS = { - 'snpcc_header': ['id', 'redshift', 'type', 'code', 'orig_sample', 'queryable', 'last_rmag', - 'ganderson_darling_normal', 'ginter_percentile_range_5', - 'gchi2', 'gstetson_K', 'gweighted_mean', 'gduration', 'gotsu_mean_diff', 'gotsu_std_lower', 'gotsu_std_upper', - 'gotsu_lower_to_all_ratio', 'glinear_fit_slope', 'glinear_fit_slope_sigma', 'glinear_fit_reduced_chi2', - 'randerson_darling_normal', 'rinter_percentile_range_5', - 'rchi2', 'rstetson_K', 'rweighted_mean', 'rduration', 'rotsu_mean_diff', 'rotsu_std_lower', 'rotsu_std_upper', - 'rotsu_lower_to_all_ratio', 'rlinear_fit_slope', 'rlinear_fit_slope_sigma', 'rlinear_fit_reduced_chi2', - 'ianderson_darling_normal', 'iinter_percentile_range_5', - 'ichi2', 'istetson_K', 'iweighted_mean', 'iduration', 'iotsu_mean_diff', 'iotsu_std_lower', 'iotsu_std_upper', - 'iotsu_lower_to_all_ratio', 'ilinear_fit_slope', 'ilinear_fit_slope_sigma', 'ilinear_fit_reduced_chi2', - 'zanderson_darling_normal', 'zinter_percentile_range_5', - 'zchi2', 'zstetson_K', 'zweighted_mean', 'zduration', 'zotsu_mean_diff', 'zotsu_std_lower', 'zotsu_std_upper', - 'zotsu_lower_to_all_ratio', 'zlinear_fit_slope', 'zlinear_fit_slope_sigma', 'zlinear_fit_reduced_chi2'], - 'snpcc_header_with_cost': ['id', 'redshift', 'type', 'code', 'orig_sample', 'queryable', 'last_rmag', - 'cost_4m', 'cost_8m', - 'ganderson_darling_normal', 'ginter_percentile_range_5', - 'gchi2', 'gstetson_K', 'gweighted_mean', 'gduration', 'gotsu_mean_diff', 'gotsu_std_lower', 'gotsu_std_upper', - 'gotsu_lower_to_all_ratio', 'glinear_fit_slope', 'glinear_fit_slope_sigma', 'glinear_fit_reduced_chi2', - 'randerson_darling_normal', 'rinter_percentile_range_5', - 'rchi2', 'rstetson_K', 'rweighted_mean', 'rduration', 'rotsu_mean_diff', 'rotsu_std_lower', 'rotsu_std_upper', - 'rotsu_lower_to_all_ratio', 'rlinear_fit_slope', 'rlinear_fit_slope_sigma', 'rlinear_fit_reduced_chi2', - 'ianderson_darling_normal', 'iinter_percentile_range_5', - 'ichi2', 'istetson_K', 'iweighted_mean', 'iduration', 'iotsu_mean_diff', 'iotsu_std_lower', 'iotsu_std_upper', - 'iotsu_lower_to_all_ratio', 'ilinear_fit_slope', 'ilinear_fit_slope_sigma', 'ilinear_fit_reduced_chi2', - 'zanderson_darling_normal', 'zinter_percentile_range_5', - 'zchi2', 'zstetson_K', 'zweighted_mean', 'zduration', 'zotsu_mean_diff', 'zotsu_std_lower', 'zotsu_std_upper', - 'zotsu_lower_to_all_ratio', 'zlinear_fit_slope', 'zlinear_fit_slope_sigma', 'zlinear_fit_reduced_chi2'] -} - -def make_features_header(filters: list, features: list, with_cost: bool = False) -> list: - """ - This function returns header list for given filters and features - - Parameters - ---------- - filters - filter values - features - feature values - with_cost - flag for adding cost values - - Returns - ------- - header - header list - """ - - header = [] - header.extend(['id', 'redshift', 'type', 'code', 'orig_sample', - #'queryable','last_rmag' - ] - ) # do we include queryable/last_rmag? 
headers sometimes have them sometimes don't - # also find where the 'with_cost' flag is used to make sure we apply there - if with_cost: - header.append('cost_4m', 'cost_8m') - for each_filter in filters: - for each_feature in features: - header.append(each_filter + each_feature) - return header - def read_file(file_path: str) -> list: """ diff --git a/src/resspect/time_domain_plasticc.py b/src/resspect/time_domain_plasticc.py index 8b12045a..5d5c15f3 100644 --- a/src/resspect/time_domain_plasticc.py +++ b/src/resspect/time_domain_plasticc.py @@ -22,17 +22,14 @@ import pandas as pd import progressbar -from resspect.lightcurves_utils import BAZIN_HEADERS +from resspect.filter_sets import FILTER_SETS from resspect.lightcurves_utils import get_query_flags from resspect.lightcurves_utils import maybe_create_directory from resspect.lightcurves_utils import PLASTICC_TARGET_TYPES from resspect.lightcurves_utils import read_plasticc_full_photometry_data from resspect.plugin_utils import fetch_feature_extractor_class - -FEATURE_EXTRACTOR_HEADERS_MAPPING = { - "Bazin": BAZIN_HEADERS -} +SUPPORTED_FEATURE_EXTRACTORS = ["Bazin"] class PLAsTiCCPhotometry: @@ -124,13 +121,17 @@ def _set_header(self, get_cost: bool = False, feature_extractor: str = 'Bazin'): Separate by 1 space. Default option uses header for Bazin features file. """ - if feature_extractor not in FEATURE_EXTRACTOR_HEADERS_MAPPING: - raise ValueError('Only Bazin headers are supported') - self._header = FEATURE_EXTRACTOR_HEADERS_MAPPING[ - feature_extractor]['plasticc_header'] - if get_cost: - self._header = FEATURE_EXTRACTOR_HEADERS_MAPPING[ - feature_extractor]['plasticc_header_with_cost'] + if feature_extractor not in SUPPORTED_FEATURE_EXTRACTORS: + raise ValueError(f'Only the following feature extractors are supported: {", ".join(SUPPORTED_FEATURE_EXTRACTORS)}') + + feature_extractor_class = fetch_feature_extractor_class(feature_extractor) + self._header = feature_extractor_class.get_feature_header( + filters=FILTER_SETS['LSST'], + with_cost=get_cost, + override_primary_columns=['id', 'redshift', 'type', 'code', 'sample'], + with_queryable=True, + with_last_rmag=True + ) def create_daily_file(self, output_dir: str, day: int, feature_extractor: str = 'Bazin', get_cost: bool = False): diff --git a/src/resspect/time_domain_snpcc.py b/src/resspect/time_domain_snpcc.py index 6bb5adda..4d02f74e 100644 --- a/src/resspect/time_domain_snpcc.py +++ b/src/resspect/time_domain_snpcc.py @@ -14,8 +14,7 @@ import os from itertools import repeat -from resspect.lightcurves_utils import BAZIN_HEADERS -from resspect.lightcurves_utils import MALANCHEV_HEADERS +from resspect.filter_sets import FILTER_SETS from resspect.lightcurves_utils import get_files_list from resspect.lightcurves_utils import get_query_flags from resspect.lightcurves_utils import maybe_create_directory @@ -26,10 +25,7 @@ __all__ = ['SNPCCPhotometry'] -FEATURE_EXTRACTOR_HEADERS_MAPPING = { - "Bazin": BAZIN_HEADERS, - "Malanchev": MALANCHEV_HEADERS -} +SUPPORTED_FEATURE_EXTRACTORS = ["Bazin", "Malanchev"] class SNPCCPhotometry: @@ -100,15 +96,19 @@ def create_daily_file(self, output_dir: str, self._features_file_name = os.path.join( output_dir, 'day_' + str(day) + '.csv') logging.info('Creating features file') - with open(self._features_file_name, 'w') as features_file: - if feature_extractor not in FEATURE_EXTRACTOR_HEADERS_MAPPING: - raise ValueError('Only Bazin and Malanchev headers are supported') - self._header = FEATURE_EXTRACTOR_HEADERS_MAPPING[ - 
feature_extractor]['snpcc_header']
-            if get_cost:
-                self._header = FEATURE_EXTRACTOR_HEADERS_MAPPING[
-                    feature_extractor]['snpcc_header_with_cost']
+        if feature_extractor not in SUPPORTED_FEATURE_EXTRACTORS:
+            raise ValueError(f'Only the following feature extractors are supported: {", ".join(SUPPORTED_FEATURE_EXTRACTORS)}')
+
+        feature_extractor_class = fetch_feature_extractor_class(feature_extractor)
+        self._header = feature_extractor_class.get_feature_header(
+            filters=FILTER_SETS['SNPCC'],
+            with_cost=get_cost,
+            with_queryable=True,
+            with_last_rmag=True
+        )
+
+        with open(self._features_file_name, 'w') as features_file:
             features_file.write(','.join(self._header) + '\n')
 
     def _verify_telescope_names(self, telescope_names: list, get_cost: bool):
diff --git a/tests/resspect/test_feature_extractor_utils.py b/tests/resspect/test_feature_extractor_utils.py
new file mode 100644
index 00000000..fef94255
--- /dev/null
+++ b/tests/resspect/test_feature_extractor_utils.py
@@ -0,0 +1,58 @@
+from resspect.feature_extractors.feature_extractor_utils import create_filter_feature_names, make_features_header
+
+def test_create_filter_feature_names():
+    filters = ['g', 'r']
+    features = ['A', 'B']
+    assert create_filter_feature_names(filters, features) == ['gA', 'gB', 'rA', 'rB']
+
+def test_create_filter_feature_names_empty():
+    filters = []
+    features = ['A', 'B']
+    assert create_filter_feature_names(filters, features) == []
+
+def test_make_features_header():
+    filters = ['g', 'r', 'i', 'z']
+    features = ['A', 'B']
+    assert make_features_header(filters, features) == ['id', 'redshift', 'type', 'code', 'orig_sample', 'gA', 'gB', 'rA', 'rB', 'iA', 'iB', 'zA', 'zB']
+
+def test_make_features_header_with_cost():
+    filters = ['g', 'r', 'i', 'z']
+    features = ['A', 'B']
+    assert make_features_header(filters, features, with_cost=True) == ['id', 'redshift', 'type', 'code', 'orig_sample', 'cost_4m', 'cost_8m', 'gA', 'gB', 'rA', 'rB', 'iA', 'iB', 'zA', 'zB']
+
+def test_make_features_header_with_queryable():
+    filters = ['g', 'r', 'i', 'z']
+    features = ['A', 'B']
+    assert make_features_header(filters, features, with_queryable=True) == ['id', 'redshift', 'type', 'code', 'orig_sample', 'queryable', 'gA', 'gB', 'rA', 'rB', 'iA', 'iB', 'zA', 'zB']
+
+def test_make_features_header_with_last_rmag():
+    filters = ['g', 'r', 'i', 'z']
+    features = ['A', 'B']
+    assert make_features_header(filters, features, with_last_rmag=True) == ['id', 'redshift', 'type', 'code', 'orig_sample', 'last_rmag', 'gA', 'gB', 'rA', 'rB', 'iA', 'iB', 'zA', 'zB']
+
+def test_make_features_header_with_override_primary_columns():
+    filters = ['g', 'r', 'i', 'z']
+    features = ['A', 'B']
+    assert make_features_header(filters, features, override_primary_columns=['new_id', 'new_redshift', 'new_type', 'new_code', 'new_orig_sample']) == ['new_id', 'new_redshift', 'new_type', 'new_code', 'new_orig_sample', 'gA', 'gB', 'rA', 'rB', 'iA', 'iB', 'zA', 'zB']
+
+def test_make_features_header_with_all_flags():
+    filters = ['g', 'r', 'i', 'z']
+    features = ['A', 'B']
+
+    expected = [
+        'new_id', 'new_redshift', 'new_type', 'new_code', 'new_orig_sample',
+        'queryable', 'last_rmag',
+        'cost_4m', 'cost_8m',
+        'gA', 'gB', 'rA', 'rB', 'iA', 'iB', 'zA', 'zB'
+    ]
+
+    result = make_features_header(
+        filters,
+        features,
+        with_cost=True,
+        with_queryable=True,
+        with_last_rmag=True,
+        override_primary_columns=['new_id', 'new_redshift', 'new_type', 'new_code', 'new_orig_sample']
+    )
+
+    assert result == expected
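The pattern these tests pin down is the heart of the series: a `LightCurve` subclass declares only its per-filter `feature_names`, and every header column is derived from them. Below is a minimal, self-contained sketch of that mechanism; `ToyExtractor` is an invented example class, and the method body is a trimmed stand-in for `make_features_header` and the `get_feature_header` classmethod added above, not the full implementation.

```
import itertools

class LightCurve:
    # Trimmed stand-in for resspect.feature_extractors.light_curve.LightCurve.
    feature_names = []

    @classmethod
    def get_feature_header(cls, filters, **kwargs):
        # Primary columns first, then one column per (filter, feature) pair,
        # mirroring make_features_header() in feature_extractor_utils.py.
        header = ['id', 'redshift', 'type', 'code', 'orig_sample']
        if kwargs.get('with_queryable', False):
            header.append('queryable')
        header += [''.join(pair) for pair in itertools.product(filters, cls.feature_names)]
        return header

class ToyExtractor(LightCurve):
    # A concrete extractor only declares its per-filter feature names.
    feature_names = ['A', 'B']

print(ToyExtractor.get_feature_header(['g', 'r'], with_queryable=True))
# ['id', 'redshift', 'type', 'code', 'orig_sample', 'queryable', 'gA', 'gB', 'rA', 'rB']
```

From 0d505708e724e0958fb6fa1d0c51dacae8f75bb4 Mon Sep 17 00:00:00 2001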
From: Drew Oldag
Date: Wed, 6 Nov 2024 16:17:37 -0800
Subject: [PATCH 2/5] Refactoring the section of database.py that builds
 self.feature_names and self.metadata_names.

---
 src/resspect/database.py                      | 41 ++++++-------------
 .../feature_extractor_utils.py                |  1 -
 .../feature_extractors/light_curve.py         | 23 ++++++++++-
 3 files changed, 34 insertions(+), 31 deletions(-)

diff --git a/src/resspect/database.py b/src/resspect/database.py
index f242883b..fcd99722 100644
--- a/src/resspect/database.py
+++ b/src/resspect/database.py
@@ -25,7 +25,6 @@
     fetch_query_strategy_class
 )
 from resspect.filter_sets import FILTER_SETS
-from resspect.feature_extractors.feature_extractor_utils import create_filter_feature_names
 
 __all__ = ['DataBase']
 
@@ -240,6 +239,10 @@ def load_features_from_file(self, path_to_features_file: str, screen=False,
             "Bump", or "Malanchev". Default is "Bazin".
         """
 
+        if survey not in ['DES', 'LSST']:
+            raise ValueError('Only "DES" and "LSST" surveys are ' + \
+                             'implemented at this point!')
+
         # read matrix with features
         if '.tar.gz' in path_to_features_file:
             tar = tarfile.open(path_to_features_file, 'r:gz')
@@ -257,41 +260,21 @@ def load_features_from_file(self, path_to_features_file: str, screen=False,
 
         # Get the list of feature names from the feature extractor class
         feature_extractor_class = fetch_feature_extractor_class(feature_extractor)
-        feature_names = feature_extractor_class.feature_names
 
         # Create the filter-feature names based on the survey.
         survey_filters = FILTER_SETS[survey]
-        self.features_names = create_filter_feature_names(survey_filters, feature_names)
+        self.features_names = feature_extractor_class.get_features(survey_filters)
 
-        if survey == 'DES':
-            self.metadata_names = ['id', 'redshift', 'type', 'code',
-                                   'orig_sample', 'queryable']
-
-            if 'last_rmag' in data.keys():
-                self.metadata_names.append('last_rmag')
+        self.metadata_names = ['id', 'redshift', 'type', 'code', 'orig_sample', 'queryable']
+        if 'objid' in data.keys():
+            self.metadata_names = ['objid', 'redshift', 'type', 'code', 'orig_sample', 'queryable']
 
-            for name in self.telescope_names:
-                if 'cost_' + name in data.keys():
-                    self.metadata_names = self.metadata_names + ['cost_' + name]
-
-        elif survey == 'LSST':
-            if 'objid' in data.keys():
-                self.metadata_names = ['objid', 'redshift', 'type', 'code',
-                                       'orig_sample', 'queryable']
-            elif 'id' in data.keys():
-                self.metadata_names = ['id', 'redshift', 'type', 'code',
-                                       'orig_sample', 'queryable']
-
-            if 'last_rmag' in data.keys():
+        if 'last_rmag' in data.keys():
             self.metadata_names.append('last_rmag')
 
-            for name in self.telescope_names:
-                if 'cost_' + name in data.keys():
-                    self.metadata_names = self.metadata_names + ['cost_' + name]
-
-        else:
-            raise ValueError('Only "DES" and "LSST" filters are ' + \
-                             'implemented at this point!')
+        for name in self.telescope_names:
+            if 'cost_' + name in data.keys():
+                self.metadata_names = self.metadata_names + ['cost_' + name]
 
         if sample == None:
             self.features = data[self.features_names].values
diff --git a/src/resspect/feature_extractors/feature_extractor_utils.py b/src/resspect/feature_extractors/feature_extractor_utils.py
index 3b3f07d2..80fecde4 100644
--- a/src/resspect/feature_extractors/feature_extractor_utils.py
+++ b/src/resspect/feature_extractors/feature_extractor_utils.py
@@ -38,7 +38,6 @@ def make_features_header(
     if kwargs.get('override_primary_columns', False):
         header = kwargs.get('override_primary_columns')
 
-
     if kwargs.get('with_queryable', False):
         header.append('queryable')
     if kwargs.get('with_last_rmag', False):
diff --git a/src/resspect/feature_extractors/light_curve.py b/src/resspect/feature_extractors/light_curve.py
index 22d1a52c..301f8c07 100644
--- a/src/resspect/feature_extractors/light_curve.py
+++ b/src/resspect/feature_extractors/light_curve.py
@@ -25,7 +25,10 @@
     read_file,
     read_plasticc_full_photometry_data,
 )
-from resspect.feature_extractors.feature_extractor_utils import make_features_header
+from resspect.feature_extractors.feature_extractor_utils import (
+    create_filter_feature_names,
+    make_features_header,
+)
 
 warnings.filterwarnings("ignore", category=RuntimeWarning)
 logging.basicConfig(level=logging.INFO)
@@ -146,6 +149,24 @@ def fit_all(self):
         """
         raise NotImplementedError()
 
+    @classmethod
+    def get_features(cls, filters: list) -> list[str]:
+        """
+        Returns the feature column names for all filters, excluding the
+        non-feature columns (id, redshift, type, code, orig_sample, ...).
+
+        Parameters
+        ----------
+        filters: list
+            List of broad band filters.
+
+        Returns
+        -------
+        list
+        """
+        return create_filter_feature_names(filters, cls.feature_names)
+
     @classmethod
     def get_feature_header(cls, filters: list, **kwargs) -> list[str]:
         """

From 4125946d5111208c9ca63cbb1ee240a6340d53ee Mon Sep 17 00:00:00 2001
From: Drew Oldag
Date: Wed, 6 Nov 2024 16:26:48 -0800
Subject: [PATCH 3/5] Adding that all-important newline at the end of the file.

---
 src/resspect/feature_extractors/feature_extractor_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/resspect/feature_extractors/feature_extractor_utils.py b/src/resspect/feature_extractors/feature_extractor_utils.py
index 80fecde4..683946ac 100644
--- a/src/resspect/feature_extractors/feature_extractor_utils.py
+++ b/src/resspect/feature_extractors/feature_extractor_utils.py
@@ -70,4 +70,4 @@ def create_filter_feature_names(filters: List[str], features: List[str]) -> List
         List of filter-feature pairs.
     """
 
-    return [''.join(pair) for pair in itertools.product(filters, features)]
\ No newline at end of file
+    return [''.join(pair) for pair in itertools.product(filters, features)]

From abe2d5b4e95bc701757cc81deb0e6f5853477b4f Mon Sep 17 00:00:00 2001
From: Michael Tauraso
Date: Fri, 25 Oct 2024 15:14:03 -0700
Subject: [PATCH 4/5] Fixing up Dockerfile, adding GitHub action to build image

- Image installs resspect from source
- Image builds on latest ubuntu, avoiding conflicts with ubuntu's version of pip
- Image sets up a data directory and a venv for resspect
- Image does not yet work with the docker-compose file or the TOM docker-compose ecosystem.
--- .github/workflows/docker-build.yml | 32 +++++++++++++++ Dockerfile | 64 ++++++++++++++++++++++++------ README.md | 2 +- 3 files changed, 85 insertions(+), 13 deletions(-) create mode 100644 .github/workflows/docker-build.yml diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml new file mode 100644 index 00000000..fbe73e43 --- /dev/null +++ b/.github/workflows/docker-build.yml @@ -0,0 +1,32 @@ + +name: Build Docker image +# Shamelessly cribbed from https://docs.docker.com/build/ci/github-actions/test-before-push/ +# with minor modifications +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +env: + TEST_TAG: user/app:test + LATEST_TAG: user/app:latest + +jobs: + docker: + runs-on: ubuntu-latest + steps: + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Build and export to Docker + uses: docker/build-push-action@v6 + with: + load: true + tags: ${{ env.TEST_TAG }} + build-args: BUILDKIT_CONTEXT_KEEP_GIT_DIR=1 + + # TODO actually have this do something with test data + - name: Test + run: | + docker run --rm ${{ env.TEST_TAG }} -c "fit_dataset --help" \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index f4baef73..976b7c01 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,27 +1,67 @@ FROM ubuntu -WORKDIR /resspect -ENV HOME / +ENV RESSPECT_DIR=/resspect +ENV RESSPECT_SRC=${RESSPECT_DIR}/resspect-src +ENV RESSPECT_VENV=${RESSPECT_DIR}/resspect-venv +ENV RESSPECT_VENV_BIN=${RESSPECT_VENV}/bin +ENV RESSPECT_WORK=${RESSPECT_DIR}/resspect-work + +WORKDIR ${RESSPECT_DIR} + +#ENV HOME=/ + RUN echo "entering resspect Dockerfile" -RUN apt-get update && \ +RUN apt-get update && \ apt-get -y upgrade && \ apt-get clean && \ - apt-get install -y python3 python3-pip postgresql-client && \ + apt-get install -y python3 python3-pip python3-venv postgresql-client git && \ rm -rf /var/lib/apt/lists/* RUN ln -s /usr/bin/python3 /usr/bin/python -RUN pip install --upgrade pip +# Copy over resspect source from local checkout +COPY . ${RESSPECT_SRC} + +# Create a venv for resspect +RUN python3 -m venv ${RESSPECT_VENV} -COPY pyproject.toml ./pyproject.toml -RUN pip install dephell[full] && \ - dephell deps convert --from=pyproject.toml --to=requirements.txt && \ - pip install -r requirements.txt && \ - pip uninstall -y dephell && \ - rm -rf /root/.cache/pip +# Use this venv for future python commands in this dockerfile +ENV PATH=${RESSPECT_VENV_BIN}:$PATH +# Activate the venv every time we log in. +RUN touch /root/.bashrc && echo "source ${RESSPECT_VENV_BIN}/activate" >> /root/.bashrc + +# Install RESSPECT and its dependencies within the virtual env. +# +# We inject a pretend version number via `git rev-parse HEAD` so that pip can find a version number for +# RESSPECT when this is built in github actions CI. Finding a version number depends deeply on available +# git metadata in the local checkout. +# +# When called from docker/build-push-action, buildkit manages its own separate checkout of the +# container source rather than using the checkout supplied by the github actions runner. +# +# While the docker/build-push-action can be configured to retain the .git directory using +# `build-args: BUILDKIT_CONTEXT_KEEP_GIT_DIR=1` -- and we rely on that below -- +# docker/build-push-action cannot be configured to download tags. +# +# Sadly, pip uses setuptools_scm to generate a version number for our package. In turn +# setuptools_scm relies on git-describe. 
git-describe relies on downloaded tags in the +# checkout to produce output which is processable by setuptools_scm, and therefore pip. +# These tags are simly not present when using docker/build-push-action. +# +# This approach has been adapted from the workaround in the setuptools_scm documentation here: +# https://setuptools-scm.readthedocs.io/en/latest/usage/#with-dockerpodman +# +# It is probably not appropriate for publishing docker containers because it does not fall back +# to the actual version number of the package during a release build. +RUN bash -c "SETUPTOOLS_SCM_PRETEND_VERSION_FOR_RESSPECT=0+$(cd ${RESSPECT_SRC} && git rev-parse HEAD) \ + pip install ${RESSPECT_SRC}" + +# Create a sample work dir for resspect +RUN mkdir -p ${RESSPECT_WORK}/results +RUN mkdir -p ${RESSPECT_WORK}/plots +RUN cp -r ${RESSPECT_SRC}/data ${RESSPECT_WORK} EXPOSE 8081 ENTRYPOINT ["bash"] - diff --git a/README.md b/README.md index 6d22d738..438735ca 100644 --- a/README.md +++ b/README.md @@ -80,6 +80,6 @@ Navigate to the repository folder and do You can now install this package with: - (RESSPECT) >>> python setup.py install + (RESSPECT) >>> pip install -e . > You may choose to create your virtual environment within the folder of the repository. If you choose to do this, you must remember to exclude the virtual environment directory from version control using e.g., ``.gitignore``. From 91cf04ebcc052e31967e8d957a773c6b91cf7d2b Mon Sep 17 00:00:00 2001 From: Michael Tauraso Date: Tue, 29 Oct 2024 14:36:00 -0700 Subject: [PATCH 5/5] Make docker compose setup work with new Dockerfile And document how to connect to tom docker-compose ecosystem --- .gitignore | 3 +++ Dockerfile | 12 +++++++---- README.md | 52 +++++++++++++++++++++++++++++++++++++++++++++- docker-compose.yml | 20 +++++++++++------- 4 files changed, 75 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index 235e6bd0..dd58fb55 100644 --- a/.gitignore +++ b/.gitignore @@ -150,3 +150,6 @@ _html/ .initialize_new_project.sh auxiliary_files/cosmo auxiliary_files/cosmo.hpp + +# Location for secrets used to log into TOM in a dev envornoment +secrets/ diff --git a/Dockerfile b/Dockerfile index 976b7c01..ff14e175 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,9 @@ FROM ubuntu ENV RESSPECT_DIR=/resspect + +# Note that this is where we copy resspect source at build time +# and also where docker-compose mounts the curent source directory in this container. ENV RESSPECT_SRC=${RESSPECT_DIR}/resspect-src ENV RESSPECT_VENV=${RESSPECT_DIR}/resspect-venv ENV RESSPECT_VENV_BIN=${RESSPECT_VENV}/bin @@ -8,9 +11,7 @@ ENV RESSPECT_WORK=${RESSPECT_DIR}/resspect-work WORKDIR ${RESSPECT_DIR} -#ENV HOME=/ - -RUN echo "entering resspect Dockerfile" +RUN echo "Entering resspect Dockerfile" RUN apt-get update && \ apt-get -y upgrade && \ @@ -20,7 +21,10 @@ RUN apt-get update && \ RUN ln -s /usr/bin/python3 /usr/bin/python -# Copy over resspect source from local checkout +# Copy over resspect source from local checkout so we can install dependencies +# and so the container will have a version of RESSPECT packaged when run standalone. +# +# If this line is changed, pleas refer to the bind mount in the compose file to ensure consistency. COPY . ${RESSPECT_SRC} # Create a venv for resspect diff --git a/README.md b/README.md index 438735ca..ee6900e8 100644 --- a/README.md +++ b/README.md @@ -82,4 +82,54 @@ You can now install this package with: (RESSPECT) >>> pip install -e . 
-> You may choose to create your virtual environment within the folder of the repository. If you choose to do this, you must remember to exclude the virtual environment directory from version control using e.g., ``.gitignore``.
+> You may choose to create your virtual environment within the folder of the repository. If you choose to do this, you must remember to exclude the virtual environment directory from version control using e.g., ``.gitignore``.
+
+# Starting the docker environment
+
+The Dockerfile in this repository can be run as a standalone environment for testing or developing RESSPECT, or in connection with tom. Start by installing Docker Desktop for your chosen platform.
+
+Note: These workflows have only been tested on Macs; however, the standalone docker image is built on Linux in GitHub Actions CI.
+
+## Using standalone
+
+To use the container standalone, first go into the root of the source directory and build the container with:
+```
+docker build -t resspect .
+```
+
+You can run the container two ways. The first way will use the version of resspect from your local checkout, which is probably what you want for development. After the
+container is built, run:
+```
+docker run -it --rm --mount type=bind,source=.,target=/resspect/resspect-src resspect
+```
+
+This will put you into a bash shell in the container, with the venv for resspect already activated and the current version of resspect from your source checkout installed.
+
+If you wish to use the version of resspect packaged at build time, simply omit `--mount type=bind,source=.,target=/resspect/resspect-src` from the command above.
+
+## Using with the tom docker-compose setup
+
+First check out tom and follow [TOM's docker compose setup instructions](https://github.com/LSSTDESC/tom_desc?tab=readme-ov-file#deploying-a-dev-environment-with-docker). You will need to load [ELAsTiCC2 data](https://github.com/LSSTDESC/tom_desc?tab=readme-ov-file#for-elasticc2) into your tom environment in order to work with RESSPECT.
+
+When you have finished that setup, go into the top-level source directory and run these two commands:
+```
+docker compose build
+docker compose up -d
+```
+
+You will now have a docker container called `resspect` which you can run resspect from. The version of resspect in use will be that of your local git checkout. That same docker container will be on the network with your tom setup, so you can access
+the tom docker container on port 8080.
+
+You can enter the `resspect` docker container to run commands with:
+```
+docker compose run resspect
+```
+
+From the resspect container you should be able to log into the tom server with:
+```
+(resspect-venv) root@cd647ac7eca5:/resspect# python3
+Python 3.12.3 (main, Sep 11 2024, 14:17:37) [GCC 13.2.0] on linux
+Type "help", "copyright", "credits" or "license" for more information.
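+>>> # Note: "tom" in the URL below is the hostname of the tom_desc service on the
+>>> # shared docker network; adjust the URL and the admin credentials to match
+>>> # your own compose setup (the empty password here is just a placeholder).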
+>>> from resspect import tom_client as tc +>>> tc = tc.TomClient(url="http://tom:8080", username='admin', password='') +``` diff --git a/docker-compose.yml b/docker-compose.yml index beee8f96..1b34ea34 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,5 +1,3 @@ -version: '3' - services: resspect: tty: true @@ -11,11 +9,13 @@ services: - "db" volumes: - type: bind - source: ./resspect - target: /resspect + source: ./ + target: /resspect/resspect-src - type: bind source: ./secrets target: /secrets + bind: + create_host_path: true environment: - DB_USER=admin - DB_PASS=verysecurepassword @@ -35,11 +35,17 @@ services: - POSTGRES_DB=resspectdb - POSTGRES_DATA_DIR=/docker-entrypoint-initdb.d volumes: - - ./resspectdb.sql:/docker-entrypoint-initdb.d/resspectdb.sql - type: bind - source: ./resspect + source: ./resspectdb.sql + target: /docker-entrypoint-initdb.d/resspectdb.sql + - type: bind + source: ./src/resspect target: /resspect - + healthcheck: + test: ["CMD-SHELL", "pg_isready -U admin -d resspectdb"] + interval: 1s + timeout: 5s + retries: 10 networks: tom_desc_default: