Skip to content

Commit

Permalink
Merge branch 'main' into issue/63/LightCurve-stub-methods
Browse files Browse the repository at this point in the history
  • Loading branch information
drewoldag authored Oct 31, 2024
2 parents 3213bb6 + 04d3e6b commit 0d1fcc8
Show file tree
Hide file tree
Showing 24 changed files with 213 additions and 146 deletions.
6 changes: 3 additions & 3 deletions benchmarks/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def time_feature_creation():
fit_snpcc(
path_to_data_dir=input_file_path,
features_file=str(features_file),
feature_extractor="malanchev"
feature_extractor="Malanchev"
)


Expand All @@ -38,7 +38,7 @@ def time_learn_loop(ml_model, strategy):
learn_loop(
LoopConfiguration(
nloops=25,
features_method="malanchev",
features_method="Malanchev",
classifier=ml_model,
strategy=strategy,
path_to_features=features_file,
Expand Down Expand Up @@ -67,7 +67,7 @@ def peakmem_learn_loop(ml_model):
learn_loop(
LoopConfiguration(
nloops=25,
features_method="malanchev",
features_method="Malanchev",
classifier=ml_model,
strategy="RandomSampling",
path_to_features=features_file,
Expand Down
6 changes: 3 additions & 3 deletions docs/learn_loop.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ For start, we can load the feature information:
>>> path_to_features_file = 'results/Bazin.csv'
>>> data = DataBase()
>>> data.load_features(path_to_features_file, feature_extractor='bazin', screen=True)
>>> data.load_features(path_to_features_file, feature_extractor='Bazin', screen=True)
Loaded 21284 samples!
Notice that this data has some pre-determine separation between training and test sample:
Expand Down Expand Up @@ -84,7 +84,7 @@ In interactive mode, you must define the required variables and use the :py:mod:
>>> from resspect.learn_loop import learn_loop
>>> nloops = 1000 # number of iterations
>>> method = 'bazin' # only option in v1.0
>>> method = 'Bazin' # only option in v1.0
>>> ml = 'RandomForest' # classifier
>>> strategy = 'RandomSampling' # learning strategy
>>> input_file = 'results/Bazin.csv' # input features file
Expand Down Expand Up @@ -149,7 +149,7 @@ following the same algorithm described in `Ishida et al., 2019 <https://cosmosta
>>> classifier = 'RandomForest'
>>> n_estimators = 1000 # number of trees in the forest
>>> feature_extraction_method = 'bazin'
>>> feature_extraction_method = 'Bazin'
>>> screen = False # if True will print many things for debuging
>>> fname_pattern = ['day_', '.csv'] # pattern on filename where different days
# are stored
Expand Down
6 changes: 3 additions & 3 deletions docs/pre_processing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ For SNPCC using Bazin features:
>>> path_to_data_dir = user_cache_dir(appname='resspect/SIMGEN_PUBLIC_DES') # raw data directory
>>> features_file = 'results/Bazin.csv' # output file
>>> feature_extractor = 'bazin'
>>> feature_extractor = 'Bazin'
>>> fit_snpcc(path_to_data_dir=path_to_data_dir, features_file=features_file)
Expand All @@ -263,7 +263,7 @@ For SNPCC using Malanchev features:
>>> path_to_data_dir = user_cache_dir(appname='resspect/SIMGEN_PUBLIC_DES') # raw data directory
>>> features_file = 'results/Malanchev.csv' # output file
>>> feature_extractor = 'malanchev'
>>> feature_extractor = 'Malanchev'
>>> fit_snpcc(path_to_data_dir=path_to_data_dir, features_file=features_file)
Expand All @@ -279,7 +279,7 @@ For PLAsTiCC:
>>> path_photo_file = '~/plasticc_train_lightcurves.csv'
>>> path_header_file = '~/plasticc_train_metadata.csv.gz'
>>> output_file = 'results/PLAsTiCC_Bazin_train.dat'
>>> feature_extractor = 'bazin'
>>> feature_extractor = 'Bazin'
>>> sample = 'train'
Expand Down
2 changes: 1 addition & 1 deletion docs/prepare_time_domain.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ You can perform the entire analysis for one day of the survey using the `SNPCCPh
>>> day = 20
>>> queryable_criteria = 2
>>> get_cost = True
>>> feature_extractor = 'bazin'
>>> feature_extractor = 'Bazin'
>>> tel_sizes=[4, 8]
>>> tel_names = ['4m', '8m']
>>> spec_SNR = 10
Expand Down
2 changes: 2 additions & 0 deletions src/resspect/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
from .query_budget_strategies import *
from .bump import *
from .feature_extractors.malanchev import *
from .feature_extractors.bazin import *
from .feature_extractors.bump import *
from .plugin_utils import *

import importlib.metadata
Expand Down
4 changes: 2 additions & 2 deletions src/resspect/build_snpcc_canonical.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def get_meta_data_from_features(path_to_features: str,
path_to_features: str
Complete path to Bazin features files
features_method: str (optional)
Method for feature extraction. Only 'bazin' is implemented.
Method for feature extraction.
"""
data = DataBase()
data.load_features(path_to_file=path_to_features, feature_extractor=features_method,
Expand All @@ -330,7 +330,7 @@ def get_meta_data_from_features(path_to_features: str,
def build_snpcc_canonical(path_to_raw_data: str, path_to_features: str,
output_canonical_file: str, output_info_file='',
compute=True, save=True, input_info_file='',
features_method='bazin', screen=False,
features_method='Bazin', screen=False,
number_of_neighbors=1):
"""Build canonical sample for SNPCC data.
Expand Down
42 changes: 18 additions & 24 deletions src/resspect/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@
import tarfile

from resspect.classifiers import *
from resspect.feature_extractors.bazin import BazinFeatureExtractor
from resspect.feature_extractors.bump import BumpFeatureExtractor
from resspect.feature_extractors.malanchev import MalanchevFeatureExtractor
from resspect.feature_extractors.light_curve import FEATURE_EXTRACTOR_REGISTRY
from resspect.query_strategies import *
from resspect.query_budget_strategies import *
from resspect.metrics import get_snpcc_metric
Expand All @@ -32,13 +30,6 @@
__all__ = ['DataBase']


FEATURE_EXTRACTOR_MAPPING = {
"bazin": BazinFeatureExtractor,
"bump": BumpFeatureExtractor,
"malanchev": MalanchevFeatureExtractor
}


class DataBase:
"""DataBase object, upon which the active learning loop is performed.
Expand Down Expand Up @@ -152,7 +143,7 @@ class DataBase:
Initiate the DataBase object and load the data.
>>> data = DataBase()
>>> data.load_features(path_to_bazin_file, method='bazin')
>>> data.load_features(path_to_bazin_file, method='Bazin')
Separate training and test samples and classify
Expand Down Expand Up @@ -223,7 +214,7 @@ def __init__(self):
self.validation_prob = np.array([])

def load_features_from_file(self, path_to_features_file: str, screen=False,
survey='DES', sample=None, feature_extractor: str='bazin'):
survey='DES', sample=None, feature_extractor: str='Bazin'):

"""Load features from file.
Expand All @@ -245,8 +236,8 @@ def load_features_from_file(self, path_to_features_file: str, screen=False,
else, read independent files for 'train' and 'test'.
Default is None.
feature_extractor: str (optional)
Function used for feature extraction. Options are "bazin",
"bump", or "malanchev". Default is "bump".
Function used for feature extraction. Options are "Bazin",
"Bump", or "Malanchev". Default is "Bazin".
"""

# read matrix with features
Expand All @@ -264,19 +255,21 @@ def load_features_from_file(self, path_to_features_file: str, screen=False,
if 'queryable' not in data.keys():
data['queryable'] = [True for i in range(data.shape[0])]


#! Make this part work better with the different feature extractors
# list of features to use
if survey == 'DES':
if feature_extractor == "bazin":
if feature_extractor == "Bazin":
self.features_names = ['gA', 'gB', 'gt0', 'gtfall', 'gtrise', 'rA',
'rB', 'rt0', 'rtfall', 'rtrise', 'iA', 'iB',
'it0', 'itfall', 'itrise', 'zA', 'zB', 'zt0',
'ztfall', 'ztrise']
elif feature_extractor == 'bump':
elif feature_extractor == 'Bump':
self.features_names = ['gp1', 'gp2', 'gp3', 'gmax_flux',
'rp1', 'rp2', 'rp3', 'rmax_flux',
'ip1', 'ip2', 'ip3', 'imax_flux',
'zp1', 'zp2', 'zp3', 'zmax_flux']
elif feature_extractor == 'malanchev':
elif feature_extractor == 'Malanchev':
self.features_names = ['ganderson_darling_normal','ginter_percentile_range_5',
'gchi2','gstetson_K','gweighted_mean','gduration',
'gotsu_mean_diff','gotsu_std_lower', 'gotsu_std_upper',
Expand Down Expand Up @@ -309,14 +302,14 @@ def load_features_from_file(self, path_to_features_file: str, screen=False,
self.metadata_names = self.metadata_names + ['cost_' + name]

elif survey == 'LSST':
if feature_extractor == "bazin":
if feature_extractor == "Bazin":
self.features_names = ['uA', 'uB', 'ut0', 'utfall', 'utrise',
'gA', 'gB', 'gt0', 'gtfall', 'gtrise',
'rA', 'rB', 'rt0', 'rtfall', 'rtrise',
'iA', 'iB', 'it0', 'itfall', 'itrise',
'zA', 'zB', 'zt0', 'ztfall', 'ztrise',
'YA', 'YB', 'Yt0', 'Ytfall', 'Ytrise']
elif feature_extractor == "malanchev":
elif feature_extractor == "Malanchev":
self.features_names = ['uanderson_darling_normal','uinter_percentile_range_5',
'uchi2','ustetson_K','uweighted_mean','uduration',
'uotsu_mean_diff','uotsu_std_lower', 'uotsu_std_upper',
Expand Down Expand Up @@ -474,7 +467,7 @@ def load_photometry_features(self, path_to_photometry_file: str,
print('\n Loaded ', self.test_metadata.shape[0],
' samples! \n')

def load_features(self, path_to_file: str, feature_extractor: str ='bazin',
def load_features(self, path_to_file: str, feature_extractor: str ='Bazin',
screen=False, survey='DES', sample=None ):
"""Load features according to the chosen feature extraction method.
Expand All @@ -487,8 +480,8 @@ def load_features(self, path_to_file: str, feature_extractor: str ='bazin',
Complete path to features file.
feature_extractor: str (optional)
Feature extraction method. The current implementation only
accepts =='bazin', 'bump', 'malanchev', or 'photometry'.
Default is 'bazin'.
accepts =='Bazin', 'Bump', 'Malanchev', or 'photometry'.
Default is 'Bazin'.
screen: bool (optional)
If True, print on screen number of light curves processed.
Default is False.
Expand All @@ -504,12 +497,13 @@ def load_features(self, path_to_file: str, feature_extractor: str ='bazin',
if feature_extractor == "photometry":
self.load_photometry_features(path_to_file, screen=screen,
survey=survey, sample=sample)
elif feature_extractor in FEATURE_EXTRACTOR_MAPPING:
elif feature_extractor in FEATURE_EXTRACTOR_REGISTRY:
self.load_features_from_file(
path_to_file, screen=screen, survey=survey,
sample=sample, feature_extractor=feature_extractor)
else:
raise ValueError('Only bazin, bump, malanchev, or photometry features are implemented!'
feature_extractors = ', '.join(FEATURE_EXTRACTOR_REGISTRY.keys())
raise ValueError(f'Only {feature_extractors} or photometry features are implemented!'
'\n Feel free to add other options.')

def load_plasticc_mjd(self, path_to_data_dir):
Expand Down
6 changes: 2 additions & 4 deletions src/resspect/feature_extractors/bazin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,11 @@
from resspect.bazin import fit_scipy
from resspect.feature_extractors.light_curve import LightCurve

__all__ = ['BazinFeatureExtractor']


class BazinFeatureExtractor(LightCurve):
class Bazin(LightCurve):
def __init__(self):
super().__init__()
self.features_names = ['a', 'b', 't0', 'tfall', 'trise']
self.features_names = ['A', 'B', 't0', 'tfall', 'trise']

def evaluate(self, time: np.array) -> dict:
"""
Expand Down
2 changes: 1 addition & 1 deletion src/resspect/feature_extractors/bump.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from resspect.feature_extractors.light_curve import LightCurve


class BumpFeatureExtractor(LightCurve):
class Bump(LightCurve):
def __init__(self):
super().__init__()
self.features_names = ['p1', 'p2', 'p3', 'time_shift', 'max_flux']
Expand Down
14 changes: 10 additions & 4 deletions src/resspect/feature_extractors/light_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@
warnings.filterwarnings("ignore", category=RuntimeWarning)
logging.basicConfig(level=logging.INFO)

__all__ = ['LightCurve']
__all__ = ['LightCurve', 'FEATURE_EXTRACTOR_REGISTRY']

FEATURE_EXTRACTOR_REGISTRY = {}

class LightCurve:
""" Light Curve object, holding meta and photometric data.
Expand Down Expand Up @@ -89,8 +90,6 @@ class LightCurve:
Check if this light can be queried in a given day.
conv_flux_mag(flux: np.array)
Convert positive flux into magnitude.
evaluate_bazin(param: list, time: np.array) -> np.array
Evaluate the Bazin function given parameter values.
load_snpcc_lc(path_to_data: str)
Reads header and photometric information for 1 light curve.
load_plasticc_lc(photo_file: str, snid: int)
Expand Down Expand Up @@ -121,6 +120,13 @@ def __init__(self):
self.sncode = 0
self.sntype = ' '

def __init_subclass__(cls):
"""Register all subclasses of LightCurve in the FEATURE_EXTRACTOR_REGISTRY."""
if cls.__name__ in FEATURE_EXTRACTOR_REGISTRY:
raise ValueError(f"Duplicate feature extractor name: {cls.__name__}")

FEATURE_EXTRACTOR_REGISTRY[cls.__name__] = cls

def fit(self, band: str) -> np.ndarray:
"""
Extract features for one filter.
Expand Down Expand Up @@ -329,7 +335,7 @@ def check_queryable(self, mjd: float, filter_lim: float, criteria: int =1,
self.last_mag = self.conv_flux_mag([fitted_flux])[0]

else:
raise ValueError('Only "Bazin" and "malanchev" features are implemented!')
raise ValueError('Only "Bazin" and "Malanchev" features are implemented!')

elif sum(surv_flag):
raise ValueError('Criteria needs to be "1" or "2". \n ' + \
Expand Down
3 changes: 1 addition & 2 deletions src/resspect/feature_extractors/malanchev.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@
import light_curve as licu
from resspect.feature_extractors.light_curve import LightCurve

__all__ = ['MalanchevFeatureExtractor']

class MalanchevFeatureExtractor(LightCurve):
class Malanchev(LightCurve):
def __init__(self):
super().__init__()
self.features_names = ['anderson_darling_normal',
Expand Down
Loading

0 comments on commit 0d1fcc8

Please sign in to comment.