Skip to content

Commit 0d1fcc8

Browse files
authored
Merge branch 'main' into issue/63/LightCurve-stub-methods
2 parents 3213bb6 + 04d3e6b commit 0d1fcc8

24 files changed

+213
-146
lines changed

benchmarks/benchmarks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def time_feature_creation():
2424
fit_snpcc(
2525
path_to_data_dir=input_file_path,
2626
features_file=str(features_file),
27-
feature_extractor="malanchev"
27+
feature_extractor="Malanchev"
2828
)
2929

3030

@@ -38,7 +38,7 @@ def time_learn_loop(ml_model, strategy):
3838
learn_loop(
3939
LoopConfiguration(
4040
nloops=25,
41-
features_method="malanchev",
41+
features_method="Malanchev",
4242
classifier=ml_model,
4343
strategy=strategy,
4444
path_to_features=features_file,
@@ -67,7 +67,7 @@ def peakmem_learn_loop(ml_model):
6767
learn_loop(
6868
LoopConfiguration(
6969
nloops=25,
70-
features_method="malanchev",
70+
features_method="Malanchev",
7171
classifier=ml_model,
7272
strategy="RandomSampling",
7373
path_to_features=features_file,

docs/learn_loop.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ For start, we can load the feature information:
1919
>>> path_to_features_file = 'results/Bazin.csv'
2020
2121
>>> data = DataBase()
22-
>>> data.load_features(path_to_features_file, feature_extractor='bazin', screen=True)
22+
>>> data.load_features(path_to_features_file, feature_extractor='Bazin', screen=True)
2323
Loaded 21284 samples!
2424
2525
Notice that this data has some pre-determine separation between training and test sample:
@@ -84,7 +84,7 @@ In interactive mode, you must define the required variables and use the :py:mod:
8484
>>> from resspect.learn_loop import learn_loop
8585
8686
>>> nloops = 1000 # number of iterations
87-
>>> method = 'bazin' # only option in v1.0
87+
>>> method = 'Bazin' # only option in v1.0
8888
>>> ml = 'RandomForest' # classifier
8989
>>> strategy = 'RandomSampling' # learning strategy
9090
>>> input_file = 'results/Bazin.csv' # input features file
@@ -149,7 +149,7 @@ following the same algorithm described in `Ishida et al., 2019 <https://cosmosta
149149
>>> classifier = 'RandomForest'
150150
>>> n_estimators = 1000 # number of trees in the forest
151151
152-
>>> feature_extraction_method = 'bazin'
152+
>>> feature_extraction_method = 'Bazin'
153153
>>> screen = False # if True will print many things for debuging
154154
>>> fname_pattern = ['day_', '.csv'] # pattern on filename where different days
155155
# are stored

docs/pre_processing.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ For SNPCC using Bazin features:
248248
249249
>>> path_to_data_dir = user_cache_dir(appname='resspect/SIMGEN_PUBLIC_DES') # raw data directory
250250
>>> features_file = 'results/Bazin.csv' # output file
251-
>>> feature_extractor = 'bazin'
251+
>>> feature_extractor = 'Bazin'
252252
253253
>>> fit_snpcc(path_to_data_dir=path_to_data_dir, features_file=features_file)
254254
@@ -263,7 +263,7 @@ For SNPCC using Malanchev features:
263263
264264
>>> path_to_data_dir = user_cache_dir(appname='resspect/SIMGEN_PUBLIC_DES') # raw data directory
265265
>>> features_file = 'results/Malanchev.csv' # output file
266-
>>> feature_extractor = 'malanchev'
266+
>>> feature_extractor = 'Malanchev'
267267
268268
>>> fit_snpcc(path_to_data_dir=path_to_data_dir, features_file=features_file)
269269
@@ -279,7 +279,7 @@ For PLAsTiCC:
279279
>>> path_photo_file = '~/plasticc_train_lightcurves.csv'
280280
>>> path_header_file = '~/plasticc_train_metadata.csv.gz'
281281
>>> output_file = 'results/PLAsTiCC_Bazin_train.dat'
282-
>>> feature_extractor = 'bazin'
282+
>>> feature_extractor = 'Bazin'
283283
284284
>>> sample = 'train'
285285

docs/prepare_time_domain.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ You can perform the entire analysis for one day of the survey using the `SNPCCPh
4848
>>> day = 20
4949
>>> queryable_criteria = 2
5050
>>> get_cost = True
51-
>>> feature_extractor = 'bazin'
51+
>>> feature_extractor = 'Bazin'
5252
>>> tel_sizes=[4, 8]
5353
>>> tel_names = ['4m', '8m']
5454
>>> spec_SNR = 10

src/resspect/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
from .query_budget_strategies import *
4646
from .bump import *
4747
from .feature_extractors.malanchev import *
48+
from .feature_extractors.bazin import *
49+
from .feature_extractors.bump import *
4850
from .plugin_utils import *
4951

5052
import importlib.metadata

src/resspect/build_snpcc_canonical.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def get_meta_data_from_features(path_to_features: str,
319319
path_to_features: str
320320
Complete path to Bazin features files
321321
features_method: str (optional)
322-
Method for feature extraction. Only 'bazin' is implemented.
322+
Method for feature extraction.
323323
"""
324324
data = DataBase()
325325
data.load_features(path_to_file=path_to_features, feature_extractor=features_method,
@@ -330,7 +330,7 @@ def get_meta_data_from_features(path_to_features: str,
330330
def build_snpcc_canonical(path_to_raw_data: str, path_to_features: str,
331331
output_canonical_file: str, output_info_file='',
332332
compute=True, save=True, input_info_file='',
333-
features_method='bazin', screen=False,
333+
features_method='Bazin', screen=False,
334334
number_of_neighbors=1):
335335
"""Build canonical sample for SNPCC data.
336336

src/resspect/database.py

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@
2121
import tarfile
2222

2323
from resspect.classifiers import *
24-
from resspect.feature_extractors.bazin import BazinFeatureExtractor
25-
from resspect.feature_extractors.bump import BumpFeatureExtractor
26-
from resspect.feature_extractors.malanchev import MalanchevFeatureExtractor
24+
from resspect.feature_extractors.light_curve import FEATURE_EXTRACTOR_REGISTRY
2725
from resspect.query_strategies import *
2826
from resspect.query_budget_strategies import *
2927
from resspect.metrics import get_snpcc_metric
@@ -32,13 +30,6 @@
3230
__all__ = ['DataBase']
3331

3432

35-
FEATURE_EXTRACTOR_MAPPING = {
36-
"bazin": BazinFeatureExtractor,
37-
"bump": BumpFeatureExtractor,
38-
"malanchev": MalanchevFeatureExtractor
39-
}
40-
41-
4233
class DataBase:
4334
"""DataBase object, upon which the active learning loop is performed.
4435
@@ -152,7 +143,7 @@ class DataBase:
152143
153144
Initiate the DataBase object and load the data.
154145
>>> data = DataBase()
155-
>>> data.load_features(path_to_bazin_file, method='bazin')
146+
>>> data.load_features(path_to_bazin_file, method='Bazin')
156147
157148
Separate training and test samples and classify
158149
@@ -223,7 +214,7 @@ def __init__(self):
223214
self.validation_prob = np.array([])
224215

225216
def load_features_from_file(self, path_to_features_file: str, screen=False,
226-
survey='DES', sample=None, feature_extractor: str='bazin'):
217+
survey='DES', sample=None, feature_extractor: str='Bazin'):
227218

228219
"""Load features from file.
229220
@@ -245,8 +236,8 @@ def load_features_from_file(self, path_to_features_file: str, screen=False,
245236
else, read independent files for 'train' and 'test'.
246237
Default is None.
247238
feature_extractor: str (optional)
248-
Function used for feature extraction. Options are "bazin",
249-
"bump", or "malanchev". Default is "bump".
239+
Function used for feature extraction. Options are "Bazin",
240+
"Bump", or "Malanchev". Default is "Bazin".
250241
"""
251242

252243
# read matrix with features
@@ -264,19 +255,21 @@ def load_features_from_file(self, path_to_features_file: str, screen=False,
264255
if 'queryable' not in data.keys():
265256
data['queryable'] = [True for i in range(data.shape[0])]
266257

258+
259+
#! Make this part work better with the different feature extractors
267260
# list of features to use
268261
if survey == 'DES':
269-
if feature_extractor == "bazin":
262+
if feature_extractor == "Bazin":
270263
self.features_names = ['gA', 'gB', 'gt0', 'gtfall', 'gtrise', 'rA',
271264
'rB', 'rt0', 'rtfall', 'rtrise', 'iA', 'iB',
272265
'it0', 'itfall', 'itrise', 'zA', 'zB', 'zt0',
273266
'ztfall', 'ztrise']
274-
elif feature_extractor == 'bump':
267+
elif feature_extractor == 'Bump':
275268
self.features_names = ['gp1', 'gp2', 'gp3', 'gmax_flux',
276269
'rp1', 'rp2', 'rp3', 'rmax_flux',
277270
'ip1', 'ip2', 'ip3', 'imax_flux',
278271
'zp1', 'zp2', 'zp3', 'zmax_flux']
279-
elif feature_extractor == 'malanchev':
272+
elif feature_extractor == 'Malanchev':
280273
self.features_names = ['ganderson_darling_normal','ginter_percentile_range_5',
281274
'gchi2','gstetson_K','gweighted_mean','gduration',
282275
'gotsu_mean_diff','gotsu_std_lower', 'gotsu_std_upper',
@@ -309,14 +302,14 @@ def load_features_from_file(self, path_to_features_file: str, screen=False,
309302
self.metadata_names = self.metadata_names + ['cost_' + name]
310303

311304
elif survey == 'LSST':
312-
if feature_extractor == "bazin":
305+
if feature_extractor == "Bazin":
313306
self.features_names = ['uA', 'uB', 'ut0', 'utfall', 'utrise',
314307
'gA', 'gB', 'gt0', 'gtfall', 'gtrise',
315308
'rA', 'rB', 'rt0', 'rtfall', 'rtrise',
316309
'iA', 'iB', 'it0', 'itfall', 'itrise',
317310
'zA', 'zB', 'zt0', 'ztfall', 'ztrise',
318311
'YA', 'YB', 'Yt0', 'Ytfall', 'Ytrise']
319-
elif feature_extractor == "malanchev":
312+
elif feature_extractor == "Malanchev":
320313
self.features_names = ['uanderson_darling_normal','uinter_percentile_range_5',
321314
'uchi2','ustetson_K','uweighted_mean','uduration',
322315
'uotsu_mean_diff','uotsu_std_lower', 'uotsu_std_upper',
@@ -474,7 +467,7 @@ def load_photometry_features(self, path_to_photometry_file: str,
474467
print('\n Loaded ', self.test_metadata.shape[0],
475468
' samples! \n')
476469

477-
def load_features(self, path_to_file: str, feature_extractor: str ='bazin',
470+
def load_features(self, path_to_file: str, feature_extractor: str ='Bazin',
478471
screen=False, survey='DES', sample=None ):
479472
"""Load features according to the chosen feature extraction method.
480473
@@ -487,8 +480,8 @@ def load_features(self, path_to_file: str, feature_extractor: str ='bazin',
487480
Complete path to features file.
488481
feature_extractor: str (optional)
489482
Feature extraction method. The current implementation only
490-
accepts =='bazin', 'bump', 'malanchev', or 'photometry'.
491-
Default is 'bazin'.
483+
accepts =='Bazin', 'Bump', 'Malanchev', or 'photometry'.
484+
Default is 'Bazin'.
492485
screen: bool (optional)
493486
If True, print on screen number of light curves processed.
494487
Default is False.
@@ -504,12 +497,13 @@ def load_features(self, path_to_file: str, feature_extractor: str ='bazin',
504497
if feature_extractor == "photometry":
505498
self.load_photometry_features(path_to_file, screen=screen,
506499
survey=survey, sample=sample)
507-
elif feature_extractor in FEATURE_EXTRACTOR_MAPPING:
500+
elif feature_extractor in FEATURE_EXTRACTOR_REGISTRY:
508501
self.load_features_from_file(
509502
path_to_file, screen=screen, survey=survey,
510503
sample=sample, feature_extractor=feature_extractor)
511504
else:
512-
raise ValueError('Only bazin, bump, malanchev, or photometry features are implemented!'
505+
feature_extractors = ', '.join(FEATURE_EXTRACTOR_REGISTRY.keys())
506+
raise ValueError(f'Only {feature_extractors} or photometry features are implemented!'
513507
'\n Feel free to add other options.')
514508

515509
def load_plasticc_mjd(self, path_to_data_dir):

src/resspect/feature_extractors/bazin.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,11 @@
99
from resspect.bazin import fit_scipy
1010
from resspect.feature_extractors.light_curve import LightCurve
1111

12-
__all__ = ['BazinFeatureExtractor']
1312

14-
15-
class BazinFeatureExtractor(LightCurve):
13+
class Bazin(LightCurve):
1614
def __init__(self):
1715
super().__init__()
18-
self.features_names = ['a', 'b', 't0', 'tfall', 'trise']
16+
self.features_names = ['A', 'B', 't0', 'tfall', 'trise']
1917

2018
def evaluate(self, time: np.array) -> dict:
2119
"""

src/resspect/feature_extractors/bump.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from resspect.feature_extractors.light_curve import LightCurve
1111

1212

13-
class BumpFeatureExtractor(LightCurve):
13+
class Bump(LightCurve):
1414
def __init__(self):
1515
super().__init__()
1616
self.features_names = ['p1', 'p2', 'p3', 'time_shift', 'max_flux']

src/resspect/feature_extractors/light_curve.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,9 @@
3434
warnings.filterwarnings("ignore", category=RuntimeWarning)
3535
logging.basicConfig(level=logging.INFO)
3636

37-
__all__ = ['LightCurve']
37+
__all__ = ['LightCurve', 'FEATURE_EXTRACTOR_REGISTRY']
3838

39+
FEATURE_EXTRACTOR_REGISTRY = {}
3940

4041
class LightCurve:
4142
""" Light Curve object, holding meta and photometric data.
@@ -89,8 +90,6 @@ class LightCurve:
8990
Check if this light can be queried in a given day.
9091
conv_flux_mag(flux: np.array)
9192
Convert positive flux into magnitude.
92-
evaluate_bazin(param: list, time: np.array) -> np.array
93-
Evaluate the Bazin function given parameter values.
9493
load_snpcc_lc(path_to_data: str)
9594
Reads header and photometric information for 1 light curve.
9695
load_plasticc_lc(photo_file: str, snid: int)
@@ -121,6 +120,13 @@ def __init__(self):
121120
self.sncode = 0
122121
self.sntype = ' '
123122

123+
def __init_subclass__(cls):
124+
"""Register all subclasses of LightCurve in the FEATURE_EXTRACTOR_REGISTRY."""
125+
if cls.__name__ in FEATURE_EXTRACTOR_REGISTRY:
126+
raise ValueError(f"Duplicate feature extractor name: {cls.__name__}")
127+
128+
FEATURE_EXTRACTOR_REGISTRY[cls.__name__] = cls
129+
124130
def fit(self, band: str) -> np.ndarray:
125131
"""
126132
Extract features for one filter.
@@ -329,7 +335,7 @@ def check_queryable(self, mjd: float, filter_lim: float, criteria: int =1,
329335
self.last_mag = self.conv_flux_mag([fitted_flux])[0]
330336

331337
else:
332-
raise ValueError('Only "Bazin" and "malanchev" features are implemented!')
338+
raise ValueError('Only "Bazin" and "Malanchev" features are implemented!')
333339

334340
elif sum(surv_flag):
335341
raise ValueError('Criteria needs to be "1" or "2". \n ' + \

0 commit comments

Comments
 (0)