diff --git a/resspect/database.py b/resspect/database.py index 7246d7c1..0da3f926 100644 --- a/resspect/database.py +++ b/resspect/database.py @@ -276,7 +276,6 @@ def load_features_from_file(self, path_to_features_file: str, screen=False, 'rp1', 'rp2', 'rp3', 'rmax_flux', 'ip1', 'ip2', 'ip3', 'imax_flux', 'zp1', 'zp2', 'zp3', 'zmax_flux'] - elif feature_extractor == 'malanchev': self.features_names = ['ganderson_darling_normal','ginter_percentile_range_5', 'gchi2','gstetson_K','gweighted_mean','gduration', @@ -310,12 +309,44 @@ def load_features_from_file(self, path_to_features_file: str, screen=False, self.metadata_names = self.metadata_names + ['cost_' + name] elif survey == 'LSST': - self.features_names = ['uA', 'uB', 'ut0', 'utfall', 'utrise', - 'gA', 'gB', 'gt0', 'gtfall', 'gtrise', - 'rA', 'rB', 'rt0', 'rtfall', 'rtrise', - 'iA', 'iB', 'it0', 'itfall', 'itrise', - 'zA', 'zB', 'zt0', 'ztfall', 'ztrise', - 'YA', 'YB', 'Yt0', 'Ytfall', 'Ytrise'] + if feature_extractor == "bazin": + self.features_names = ['uA', 'uB', 'ut0', 'utfall', 'utrise', + 'gA', 'gB', 'gt0', 'gtfall', 'gtrise', + 'rA', 'rB', 'rt0', 'rtfall', 'rtrise', + 'iA', 'iB', 'it0', 'itfall', 'itrise', + 'zA', 'zB', 'zt0', 'ztfall', 'ztrise', + 'YA', 'YB', 'Yt0', 'Ytfall', 'Ytrise'] + elif feature_extractor == "malanchev": + self.features_names = ['uanderson_darling_normal','uinter_percentile_range_5', + 'uchi2','ustetson_K','uweighted_mean','uduration', + 'uotsu_mean_diff','uotsu_std_lower', 'uotsu_std_upper', + 'uotsu_lower_to_all_ratio', 'ulinear_fit_slope', + 'ulinear_fit_slope_sigma','ulinear_fit_reduced_chi2', + 'ganderson_darling_normal','ginter_percentile_range_5', + 'gchi2','gstetson_K','gweighted_mean','gduration', + 'gotsu_mean_diff','gotsu_std_lower', 'gotsu_std_upper', + 'gotsu_lower_to_all_ratio', 'glinear_fit_slope', + 'glinear_fit_slope_sigma','glinear_fit_reduced_chi2', + 'randerson_darling_normal', 'rinter_percentile_range_5', + 'rchi2', 'rstetson_K', 'rweighted_mean','rduration', + 'rotsu_mean_diff','rotsu_std_lower', 'rotsu_std_upper', + 'rotsu_lower_to_all_ratio', 'rlinear_fit_slope', + 'rlinear_fit_slope_sigma','rlinear_fit_reduced_chi2', + 'ianderson_darling_normal','iinter_percentile_range_5', + 'ichi2', 'istetson_K', 'iweighted_mean','iduration', + 'iotsu_mean_diff','iotsu_std_lower', 'iotsu_std_upper', + 'iotsu_lower_to_all_ratio', 'ilinear_fit_slope', + 'ilinear_fit_slope_sigma','ilinear_fit_reduced_chi2', + 'zanderson_darling_normal','zinter_percentile_range_5', + 'zchi2', 'zstetson_K', 'zweighted_mean','zduration', + 'zotsu_mean_diff','zotsu_std_lower', 'zotsu_std_upper', + 'zotsu_lower_to_all_ratio', 'zlinear_fit_slope', + 'zlinear_fit_slope_sigma','zlinear_fit_reduced_chi2', + 'Yanderson_darling_normal','Yinter_percentile_range_5', + 'Ychi2', 'Ystetson_K', 'Yweighted_mean','Yduration', + 'Yotsu_mean_diff','Yotsu_std_lower', 'Yotsu_std_upper', + 'Yotsu_lower_to_all_ratio', 'Ylinear_fit_slope', + 'Ylinear_fit_slope_sigma','Ylinear_fit_reduced_chi2'] if 'objid' in data.keys(): self.metadata_names = ['objid', 'redshift', 'type', 'code', diff --git a/resspect/fit_lightcurves.py b/resspect/fit_lightcurves.py index 48e77191..3d24a86a 100644 --- a/resspect/fit_lightcurves.py +++ b/resspect/fit_lightcurves.py @@ -242,7 +242,7 @@ def fit_plasticc(path_photo_file: str, path_header_file: str, logging.info("Features have been saved to: %s", output_file) def _TOM_sample_fit( - id: str, dic: dict, feature_extractor: str): + obj_dic: dict, feature_extractor: str): """ Reads SNPCC file and performs fit. @@ -255,20 +255,20 @@ def _TOM_sample_fit( Options are 'bazin', 'bump', or 'malanchev'. """ light_curve_data = FEATURE_EXTRACTOR_MAPPING[feature_extractor]() - light_curve_data.photometry = pd.DataFrame(dic[id]['photometry']) + light_curve_data.photometry = pd.DataFrame(obj_dic['photometry']) light_curve_data.dataset_name = 'TOM' light_curve_data.filters = ['u', 'g', 'r', 'i', 'z', 'Y'] - light_curve_data.id = id - light_curve_data.redshift = dic[id]['redshift'] + light_curve_data.id = obj_dic['objectid'] + light_curve_data.redshift = obj_dic['redshift'] light_curve_data.sntype = 'unknown' - light_curve_data.sncode = dic[id]['sncode'] + light_curve_data.sncode = obj_dic['sncode'] light_curve_data.sample = 'N/A' light_curve_data.fit_all() return light_curve_data -def fit_TOM(data_dic: dict, features_file: str, +def fit_TOM(data_dic: dict, output_features_file: str, number_of_processors: int = 1, feature_extractor: str = 'bazin'): """ @@ -278,7 +278,7 @@ def fit_TOM(data_dic: dict, features_file: str, ---------- data_dic: str Dictionary containing the photometry for all light curves. - features_file: str + output_features_file: str Path to output file where results should be stored. number_of_processors: int, default 1 Number of cpu processes to use. @@ -289,23 +289,23 @@ def fit_TOM(data_dic: dict, features_file: str, header = TOM_FEATURES_HEADER elif feature_extractor == 'malanchev': header = TOM_MALANCHEV_FEATURES_HEADER - + multi_process = multiprocessing.Pool(number_of_processors) logging.info("Starting TOM " + feature_extractor + " fit...") - with open(features_file, 'w') as snpcc_features_file: - snpcc_features_file.write(','.join(header) + '\n') + with open(output_features_file, 'w') as TOM_features_file: + TOM_features_file.write(','.join(header) + '\n') for light_curve_data in multi_process.starmap( _TOM_sample_fit, zip( - data_dic, repeat(data_dic), repeat(feature_extractor))): + data_dic, repeat(feature_extractor))): if 'None' not in light_curve_data.features: write_features_to_output_file( - light_curve_data, snpcc_features_file) - logging.info("Features have been saved to: %s", features_file) + light_curve_data, TOM_features_file) + logging.info("Features have been saved to: %s", output_features_file) def request_TOM_data(url: str = "https://desc-tom-2.lbl.gov", username: str = None, passwordfile: str = None, password: str = None, detected_since_mjd: float = None, - detected_in_last_days: float = None,): + detected_in_last_days: float = None, mjdnow: float = None): tom = TomClient(url = url, username = username, passwordfile = passwordfile, password = password) dic = {} @@ -313,21 +313,12 @@ def request_TOM_data(url: str = "https://desc-tom-2.lbl.gov", username: str = No dic['detected_since_mjd'] = detected_since_mjd if detected_in_last_days is not None: dic['detected_in_last_days'] = detected_in_last_days - res = tom.post('elasticc2/gethotsne', dic) + if mjdnow is not None: + dic['mjd_now'] = mjdnow + res = tom.post('elasticc2/gethottransients', json = dic) data_dic = res.json() return data_dic -def submit_queries_to_TOM(objectids: list, priorities: list, requester: str='resspect'): - req = { 'requester': requester, - 'objectids': objectids, - 'priorities': priorities} - res = TomClient.request( 'POST', 'elasticc2/askforspectrum', json=req ) - dic = res.json() - if res.satus_code != 200: - raise ValueError('Request failed, ' + res.text + ". Status code: " + str(res.status_code)) - - if dic['status'] == 'error': - raise ValueError('Request failed, ' + dic.json()['error']) def main(): return None diff --git a/resspect/time_domain_loop.py b/resspect/time_domain_loop.py index 67f2a67d..bc044d37 100644 --- a/resspect/time_domain_loop.py +++ b/resspect/time_domain_loop.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = ['time_domain_loop', 'load_dataset'] +__all__ = ['time_domain_loop', 'load_dataset', 'submit_queries_to_TOM'] import os from typing import Union, Tuple @@ -24,6 +24,7 @@ import progressbar from resspect import DataBase +from resspect.tom_client import TomClient def load_dataset(file_names_dict: dict, survey_name: str = 'DES', @@ -799,11 +800,12 @@ def process_next_day_loop( return light_curve_data -def submit_queries_to_TOM(objectids: list, priorities: list, requester: str='resspect'): +def submit_queries_to_TOM(username, passwordfile, objectids: list, priorities: list, requester: str='resspect'): + tom = TomClient(url = "https://desc-tom-2.lbl.gov", username = username, passwordfile = passwordfile) req = { 'requester': requester, 'objectids': objectids, 'priorities': priorities} - res = TomClient.request( 'POST', 'elasticc2/askforspectrum', json=req ) + res = tom.request( 'POST', 'elasticc2/askforspectrum', json=req ) dic = res.json() if res.satus_code != 200: raise ValueError('Request failed, ' + res.text + ". Status code: " + str(res.status_code))