Skip to content

Commit

Permalink
TOM updates
Browse files Browse the repository at this point in the history
  • Loading branch information
AmandaWasserman committed May 29, 2024
1 parent 4048019 commit 8d9b7a3
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 36 deletions.
45 changes: 38 additions & 7 deletions resspect/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,6 @@ def load_features_from_file(self, path_to_features_file: str, screen=False,
'rp1', 'rp2', 'rp3', 'rmax_flux',
'ip1', 'ip2', 'ip3', 'imax_flux',
'zp1', 'zp2', 'zp3', 'zmax_flux']

elif feature_extractor == 'malanchev':
self.features_names = ['ganderson_darling_normal','ginter_percentile_range_5',
'gchi2','gstetson_K','gweighted_mean','gduration',
Expand Down Expand Up @@ -310,12 +309,44 @@ def load_features_from_file(self, path_to_features_file: str, screen=False,
self.metadata_names = self.metadata_names + ['cost_' + name]

elif survey == 'LSST':
self.features_names = ['uA', 'uB', 'ut0', 'utfall', 'utrise',
'gA', 'gB', 'gt0', 'gtfall', 'gtrise',
'rA', 'rB', 'rt0', 'rtfall', 'rtrise',
'iA', 'iB', 'it0', 'itfall', 'itrise',
'zA', 'zB', 'zt0', 'ztfall', 'ztrise',
'YA', 'YB', 'Yt0', 'Ytfall', 'Ytrise']
if feature_extractor == "bazin":
self.features_names = ['uA', 'uB', 'ut0', 'utfall', 'utrise',
'gA', 'gB', 'gt0', 'gtfall', 'gtrise',
'rA', 'rB', 'rt0', 'rtfall', 'rtrise',
'iA', 'iB', 'it0', 'itfall', 'itrise',
'zA', 'zB', 'zt0', 'ztfall', 'ztrise',
'YA', 'YB', 'Yt0', 'Ytfall', 'Ytrise']
elif feature_extractor == "malanchev":
self.features_names = ['uanderson_darling_normal','uinter_percentile_range_5',
'uchi2','ustetson_K','uweighted_mean','uduration',
'uotsu_mean_diff','uotsu_std_lower', 'uotsu_std_upper',
'uotsu_lower_to_all_ratio', 'ulinear_fit_slope',
'ulinear_fit_slope_sigma','ulinear_fit_reduced_chi2',
'ganderson_darling_normal','ginter_percentile_range_5',
'gchi2','gstetson_K','gweighted_mean','gduration',
'gotsu_mean_diff','gotsu_std_lower', 'gotsu_std_upper',
'gotsu_lower_to_all_ratio', 'glinear_fit_slope',
'glinear_fit_slope_sigma','glinear_fit_reduced_chi2',
'randerson_darling_normal', 'rinter_percentile_range_5',
'rchi2', 'rstetson_K', 'rweighted_mean','rduration',
'rotsu_mean_diff','rotsu_std_lower', 'rotsu_std_upper',
'rotsu_lower_to_all_ratio', 'rlinear_fit_slope',
'rlinear_fit_slope_sigma','rlinear_fit_reduced_chi2',
'ianderson_darling_normal','iinter_percentile_range_5',
'ichi2', 'istetson_K', 'iweighted_mean','iduration',
'iotsu_mean_diff','iotsu_std_lower', 'iotsu_std_upper',
'iotsu_lower_to_all_ratio', 'ilinear_fit_slope',
'ilinear_fit_slope_sigma','ilinear_fit_reduced_chi2',
'zanderson_darling_normal','zinter_percentile_range_5',
'zchi2', 'zstetson_K', 'zweighted_mean','zduration',
'zotsu_mean_diff','zotsu_std_lower', 'zotsu_std_upper',
'zotsu_lower_to_all_ratio', 'zlinear_fit_slope',
'zlinear_fit_slope_sigma','zlinear_fit_reduced_chi2',
'Yanderson_darling_normal','Yinter_percentile_range_5',
'Ychi2', 'Ystetson_K', 'Yweighted_mean','Yduration',
'Yotsu_mean_diff','Yotsu_std_lower', 'Yotsu_std_upper',
'Yotsu_lower_to_all_ratio', 'Ylinear_fit_slope',
'Ylinear_fit_slope_sigma','Ylinear_fit_reduced_chi2']

if 'objid' in data.keys():
self.metadata_names = ['objid', 'redshift', 'type', 'code',
Expand Down
43 changes: 17 additions & 26 deletions resspect/fit_lightcurves.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def fit_plasticc(path_photo_file: str, path_header_file: str,
logging.info("Features have been saved to: %s", output_file)

def _TOM_sample_fit(
id: str, dic: dict, feature_extractor: str):
obj_dic: dict, feature_extractor: str):
"""
Reads SNPCC file and performs fit.
Expand All @@ -255,20 +255,20 @@ def _TOM_sample_fit(
Options are 'bazin', 'bump', or 'malanchev'.
"""
light_curve_data = FEATURE_EXTRACTOR_MAPPING[feature_extractor]()
light_curve_data.photometry = pd.DataFrame(dic[id]['photometry'])
light_curve_data.photometry = pd.DataFrame(obj_dic['photometry'])
light_curve_data.dataset_name = 'TOM'
light_curve_data.filters = ['u', 'g', 'r', 'i', 'z', 'Y']
light_curve_data.id = id
light_curve_data.redshift = dic[id]['redshift']
light_curve_data.id = obj_dic['objectid']
light_curve_data.redshift = obj_dic['redshift']
light_curve_data.sntype = 'unknown'
light_curve_data.sncode = dic[id]['sncode']
light_curve_data.sncode = obj_dic['sncode']
light_curve_data.sample = 'N/A'

light_curve_data.fit_all()

return light_curve_data

def fit_TOM(data_dic: dict, features_file: str,
def fit_TOM(data_dic: dict, output_features_file: str,
number_of_processors: int = 1,
feature_extractor: str = 'bazin'):
"""
Expand All @@ -278,7 +278,7 @@ def fit_TOM(data_dic: dict, features_file: str,
----------
data_dic: str
Dictionary containing the photometry for all light curves.
features_file: str
output_features_file: str
Path to output file where results should be stored.
number_of_processors: int, default 1
Number of cpu processes to use.
Expand All @@ -289,45 +289,36 @@ def fit_TOM(data_dic: dict, features_file: str,
header = TOM_FEATURES_HEADER
elif feature_extractor == 'malanchev':
header = TOM_MALANCHEV_FEATURES_HEADER

multi_process = multiprocessing.Pool(number_of_processors)
logging.info("Starting TOM " + feature_extractor + " fit...")
with open(features_file, 'w') as snpcc_features_file:
snpcc_features_file.write(','.join(header) + '\n')
with open(output_features_file, 'w') as TOM_features_file:
TOM_features_file.write(','.join(header) + '\n')

for light_curve_data in multi_process.starmap(
_TOM_sample_fit, zip(
data_dic, repeat(data_dic), repeat(feature_extractor))):
data_dic, repeat(feature_extractor))):
if 'None' not in light_curve_data.features:
write_features_to_output_file(
light_curve_data, snpcc_features_file)
logging.info("Features have been saved to: %s", features_file)
light_curve_data, TOM_features_file)
logging.info("Features have been saved to: %s", output_features_file)

def request_TOM_data(url: str = "https://desc-tom-2.lbl.gov", username: str = None,
passwordfile: str = None, password: str = None, detected_since_mjd: float = None,
detected_in_last_days: float = None,):
detected_in_last_days: float = None, mjdnow: float = None):
tom = TomClient(url = url, username = username, passwordfile = passwordfile,
password = password)
dic = {}
if detected_since_mjd is not None:
dic['detected_since_mjd'] = detected_since_mjd
if detected_in_last_days is not None:
dic['detected_in_last_days'] = detected_in_last_days
res = tom.post('elasticc2/gethotsne', dic)
if mjdnow is not None:
dic['mjd_now'] = mjdnow
res = tom.post('elasticc2/gethottransients', json = dic)
data_dic = res.json()
return data_dic

def submit_queries_to_TOM(objectids: list, priorities: list, requester: str='resspect'):
req = { 'requester': requester,
'objectids': objectids,
'priorities': priorities}
res = TomClient.request( 'POST', 'elasticc2/askforspectrum', json=req )
dic = res.json()
if res.satus_code != 200:
raise ValueError('Request failed, ' + res.text + ". Status code: " + str(res.status_code))

if dic['status'] == 'error':
raise ValueError('Request failed, ' + dic.json()['error'])

def main():
return None
Expand Down
8 changes: 5 additions & 3 deletions resspect/time_domain_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ['time_domain_loop', 'load_dataset']
__all__ = ['time_domain_loop', 'load_dataset', 'submit_queries_to_TOM']

import os
from typing import Union, Tuple
Expand All @@ -24,6 +24,7 @@
import progressbar

from resspect import DataBase
from resspect.tom_client import TomClient


def load_dataset(file_names_dict: dict, survey_name: str = 'DES',
Expand Down Expand Up @@ -799,11 +800,12 @@ def process_next_day_loop(
return light_curve_data


def submit_queries_to_TOM(objectids: list, priorities: list, requester: str='resspect'):
def submit_queries_to_TOM(username, passwordfile, objectids: list, priorities: list, requester: str='resspect'):
tom = TomClient(url = "https://desc-tom-2.lbl.gov", username = username, passwordfile = passwordfile)
req = { 'requester': requester,
'objectids': objectids,
'priorities': priorities}
res = TomClient.request( 'POST', 'elasticc2/askforspectrum', json=req )
res = tom.request( 'POST', 'elasticc2/askforspectrum', json=req )
dic = res.json()
if res.satus_code != 200:
raise ValueError('Request failed, ' + res.text + ". Status code: " + str(res.status_code))
Expand Down

0 comments on commit 8d9b7a3

Please sign in to comment.