generalize fit() for non Ia applications #56

Merged
merged 3 commits into from
Oct 30, 2024
45 changes: 34 additions & 11 deletions src/resspect/fit_lightcurves.py
@@ -309,7 +309,7 @@ def fit_TOM(data_dic: dict, output_features_file: str,
logging.info("Features have been saved to: %s", output_features_file)

def _sample_fit(
obj_dic: dict, feature_extractor: str, filters: list, type: str, Ia_code: list,
obj_dic: dict, feature_extractor: str, filters: list, type: str, one_code: list,
additional_info: list):
"""
Reads general file and performs fit.
@@ -328,29 +328,32 @@ def _sample_fit(
light_curve_data.id = obj_dic['objectid']
light_curve_data.redshift = obj_dic['redshift']
light_curve_data.sncode = obj_dic['sncode']
if light_curve_data.sncode in Ia_code:
light_curve_data.sntype = 'Ia'
if light_curve_data.sncode in one_code:
light_curve_data.sntype = 'Ia' #just labeling all positive classes as Ia
#unsure what changing this might affect in database.py
else:
light_curve_data.sntype = 'other'
light_curve_data.sample = type
light_curve_data.additional_info = additional_info

light_curve_data.additional_info = []
for info in additional_info:
light_curve_data.additional_info.append(obj_dic[info])
light_curve_data.fit_all()

return light_curve_data

def fit(data_dic: dict, output_features_file: str,
number_of_processors: int = MAX_NUMBER_OF_PROCESSES,
feature_extractor: str = 'bazin', filters: list = ['SNPCC'],
features: list = [], type: str = 'unspecified', Ia_code: list = [10],
features: list = [], type: str = 'unspecified', one_code: list = [10],
additional_info: list = []):
"""
Perform fit to all objects from a generalized dataset.

Parameters
----------
data_dic: str
Dictionary containing the photometry for all light curves.
data_dic: list
Collaborator
Suggested change
data_dic: list
data_dic: list[dict]

Minor point - doesn't need to be fixed in this PR - but since data_dic is now expected to be a list, it might make sense to simply call it data.
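
One way to read the suggested `list[dict]` annotation, offered only as a sketch: the element layout listed in the docstring could be spelled out with a `TypedDict`. The names `Photometry`, `LightCurveRecord`, and `count_observations` are hypothetical and not part of resspect; the key names come from the docstring in this diff, and the value types are assumptions.

```python
from typing import TypedDict, Union


class Photometry(TypedDict):
    """Per-observation columns; the lists are assumed to share one length."""
    mjd: list[float]
    band: list[str]
    flux: list[float]
    fluxerr: list[float]


class LightCurveRecord(TypedDict):
    """Required keys of one element, as listed in the docstring."""
    objectid: Union[int, str]      # assumed type; the diff does not specify it
    photometry: Photometry
    redshift: Union[float, str]    # a number, or the string 'unknown'
    sncode: int


def count_observations(data: list[LightCurveRecord]) -> list[int]:
    """Return the number of observations in each light curve."""
    return [len(record["photometry"]["mjd"]) for record in data]
```

A strict `TypedDict` would flag the optional keys (for example `'RA'` and `'dec'`) under a static type checker, so a plain `list[dict]` may remain the more practical annotation.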

List of dictionaries containing the photometry for each light
curve. Example provided below.
output_features_file: str
Path to output file where results should be stored.
number_of_processors: int, default 1
@@ -364,12 +367,32 @@ def fit(data_dic: dict, output_features_file: str,
the feature extractor.
type: str
Type of data: train, test, validation, pool
Ia_code: list
List of Ia codes to be used. Default is 10 from ELAsTiCC.
one_code: list
List of codes to be used to define a positive class.
Default is 10 from ELAsTiCC type Ia SN.
additional_info: list
List of additional header information to be used in other classifiers.
For example, RA and dec for GHOST feature extraction

Example of a data_dic element:
------------------------------
Required keys
'objectid'
object id
'photometry'
dictionary containing keys 'mjd', 'band', 'flux', 'fluxerr'.
each entry contains a list of mjd, band, flux, fluxerr for
each observation
'redshift'
redshift of the object if known, 'unknown' otherwise
'sncode'
number to delineate what type of transient the object is.
one_code will translate this into '1'/'0' for positive/negative

Optional keys; anything you might need for your classifier.
Need to specify the keys as a list in 'additional_info'. For example
'RA'
'dec'
"""
if feature_extractor == 'bazin':
header = TOM_FEATURES_HEADER
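
To make the docstring's key list concrete, here is a minimal sketch of one `data_dic` element together with a check for the required keys. All values are made up for illustration, and `missing_keys` is a hypothetical helper, not a resspect function.

```python
REQUIRED_KEYS = ("objectid", "photometry", "redshift", "sncode")
PHOTOMETRY_KEYS = ("mjd", "band", "flux", "fluxerr")

# Illustrative values only; not taken from a real survey.
example_element = {
    "objectid": 1001,
    "photometry": {
        "mjd": [60000.0, 60003.0, 60007.0],
        "band": ["g", "r", "g"],
        "flux": [12.1, 15.4, 9.8],
        "fluxerr": [1.0, 1.2, 0.9],
    },
    "redshift": "unknown",
    "sncode": 10,
    # Optional keys; they would be named in additional_info.
    "RA": 150.1,
    "dec": 2.2,
}


def missing_keys(element: dict) -> list:
    """Return required keys and photometry columns absent from element."""
    missing = [key for key in REQUIRED_KEYS if key not in element]
    photometry = element.get("photometry", {})
    missing += [f"photometry.{key}" for key in PHOTOMETRY_KEYS
                if key not in photometry]
    return missing
```

Running such a check over the whole list before calling `fit()` would surface a malformed element early, rather than as a `KeyError` inside a worker process.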
@@ -409,7 +432,7 @@ def fit(data_dic: dict, output_features_file: str,
for light_curve_data in multi_process.starmap(
_sample_fit, zip(
data_dic, repeat(feature_extractor), repeat(filters), repeat(type),
repeat(Ia_code), repeat(additional_info))):
repeat(one_code), repeat(additional_info))):
if 'None' not in light_curve_data.features:
write_features_to_output_file(
light_curve_data, features_file)
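
The effect of `one_code` and `additional_info` inside `_sample_fit` can be sketched in isolation. This is a standalone illustration with hypothetical helper names (`label_from_sncode`, `collect_additional_info`); it mirrors the logic shown in the diff but does not use resspect's light-curve classes.

```python
def label_from_sncode(sncode, one_code):
    """Mirror the diff: codes in one_code are labelled 'Ia', all others 'other'."""
    return "Ia" if sncode in one_code else "other"


def collect_additional_info(obj_dic, additional_info):
    """Look up each requested key in the object dictionary, preserving order."""
    return [obj_dic[key] for key in additional_info]


# Made-up object: code 42 stands in for some non-Ia positive class.
obj = {"objectid": 7, "sncode": 42, "RA": 10.5, "dec": -3.0}
label = label_from_sncode(obj["sncode"], one_code=[42, 43])   # "Ia"
extras = collect_additional_info(obj, ["RA", "dec"])           # [10.5, -3.0]
```

As in the diff, a key listed in `additional_info` but absent from the object raises a `KeyError`, and every positive class is still labelled `'Ia'` regardless of its actual type.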