Skip to content

Commit 8a66bd1

Browse files
authored
Switch to Era5t data (#29)
* Add paths and global variables in config
* Change lines to switch to ERA5T
* Add predict for XGBoost model and modify add_lags
* Modify unit tests
1 parent 8c840df commit 8a66bd1

File tree

5 files changed

+59
-27
lines changed

5 files changed

+59
-27
lines changed

pyro_risks/config.py

+25
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
TEST_FR_VIIRS_XLSX_FALLBACK: str = f"{DATA_FALLBACK}/test_data_VIIRS.xlsx"
2525
TEST_FR_VIIRS_JSON_FALLBACK: str = f"{DATA_FALLBACK}/test_data_VIIRS.json"
2626
TEST_FR_ERA5_2019_FALLBACK: str = f"{DATA_FALLBACK}/test_data_ERA5_2019.nc"
27+
TEST_FR_ERA5T_FALLBACK: str = f"{DATA_FALLBACK}/test_era5t_to_merge.nc"
2728
TEST_FWI_FALLBACK: str = f"{DATA_FALLBACK}/test_data_FWI.csv"
2829
TEST_FWI_TO_PREDICT: str = f"{DATA_FALLBACK}/fwi_test_to_predict.csv"
2930
TEST_ERA_TO_PREDICT: str = f"{DATA_FALLBACK}/era_test_to_predict.csv"
@@ -37,7 +38,30 @@
3738
CDS_API_KEY = os.getenv('CDS_API_KEY')
3839

3940
RFMODEL_PATH: str = f"{DATA_FALLBACK}/pyrorisk_rfc_111220.pkl"
41+
RFMODEL_ERA5T_PATH: str = f"{DATA_FALLBACK}/pyrorisk_rfc_era5t_151220.pkl"
4042
XGBMODEL_PATH: str = f"{DATA_FALLBACK}/pyrorisk_xgb_091220.pkl"
43+
XGBMODEL_ERA5T_PATH: str = f"{DATA_FALLBACK}/pyrorisk_xgb_era5t_151220.pkl"
44+
45+
FWI_VARS = ['fwi', 'ffmc', 'dmc', 'dc', 'isi', 'bui', 'dsr']
46+
WEATHER_VARS = [
47+
'u10', 'v10', 'd2m', 't2m', 'fal', 'lai_hv', 'lai_lv', 'skt',
48+
'asn', 'snowc', 'rsn', 'sde', 'sd', 'sf', 'smlt', 'stl1', 'stl2',
49+
'stl3', 'stl4', 'slhf', 'ssr', 'str', 'sp', 'sshf', 'ssrd', 'strd', 'tsn', 'tp'
50+
]
51+
WEATHER_ERA5T_VARS = ['asn', 'd2m', 'e', 'es', 'fal', 'lai_hv', 'lai_lv', 'lblt',
52+
'licd', 'lict', 'lmld', 'lmlt', 'lshf', 'ltlt', 'pev', 'ro', 'rsn', 'sd', 'sf', 'skt',
53+
'slhf', 'smlt', 'sp', 'src', 'sro', 'sshf', 'ssr', 'ssrd', 'ssro', 'stl1', 'stl2', 'stl3',
54+
'stl4', 'str', 'strd', 'swvl1', 'swvl2', 'swvl3', 'swvl4', 't2m', 'tp', 'tsn', 'u10', 'v10']
55+
56+
MODEL_ERA5T_VARS = ['str_max', 'str_mean', 'ffmc_min', 'str_min', 'ffmc_mean',
57+
'str_mean_lag1', 'str_max_lag1', 'str_min_lag1', 'isi_min',
58+
'ffmc_min_lag1', 'isi_mean', 'ffmc_mean_lag1', 'ffmc_std', 'ffmc_max',
59+
'isi_min_lag1', 'isi_mean_lag1', 'ffmc_max_lag1', 'asn_std', 'strd_max',
60+
'ssrd_min', 'strd_mean', 'isi_max', 'strd_min', 'd2m_min', 'asn_min',
61+
'ssr_min', 'ffmc_min_lag3', 'ffmc_std_lag1', 'lai_hv_mean_lag7',
62+
'str_max_lag3', 'str_mean_lag3', 'rsn_std_lag1', 'fwi_mean', 'ssr_mean',
63+
'ssrd_mean', 'swvl1_mean', 'rsn_std_lag3', 'isi_max_lag1', 'd2m_mean',
64+
'rsn_std']
4165

4266
MODEL_VARIABLES = ['ffmc_min', 'str_mean', 'str_min', 'str_max', 'ffmc_mean', 'isi_min',
4367
'ffmc_min_lag1', 'strd_mean', 'isi_mean', 'strd_min', 'strd_max',
@@ -48,6 +72,7 @@
4872
'strd_min_lag1', 'ffmc_min_lag3', 'ffmc_std_lag1', 'strd_mean_lag1',
4973
'rsn_mean_lag1', 'fwi_mean', 'isi_max_lag1', 'sd_max', 'strd_max_lag1',
5074
'rsn_mean', 'snowc_std_lag7', 'stl1_std_lag3']
75+
5176
TRAIN_SELECTED_DEP = ['Aisne', 'Alpes-Maritimes', 'Ardèche', 'Ariège', 'Aude', 'Aveyron',
5277
'Cantal', 'Eure', 'Eure-et-Loir', 'Gironde', 'Haute-Corse', 'Hautes-Pyrénées',
5378
'Hérault', 'Indre', 'Landes', 'Loiret', 'Lozère', 'Marne', 'Oise',

pyro_risks/datasets/era_fwi_viirs.py

+7-14
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import logging
22
import pandas as pd
33

4-
from pyro_risks.datasets import NASAFIRMS_VIIRS, ERA5Land
4+
from pyro_risks.datasets import NASAFIRMS_VIIRS, ERA5Land, ERA5T
55
from pyro_risks.datasets.utils import get_intersection_range
66
from pyro_risks.datasets.fwi import GwisFwi
7+
from pyro_risks import config as cfg
78

89
__all__ = ["MergedEraFwiViirs"]
910

@@ -28,19 +29,15 @@ def process_dataset_to_predict(fwi, era):
2829

2930
# Group fwi dataframe by day and department and compute min, max, mean, std
3031
agg_fwi_df = fwi_df.groupby(['day', 'nom'])[
31-
'fwi', 'ffmc', 'dmc', 'dc', 'isi', 'bui', 'dsr'
32-
].agg(['min', 'max', 'mean', 'std']).reset_index()
32+
cfg.FWI_VARS].agg(['min', 'max', 'mean', 'std']).reset_index()
3333
agg_fwi_df.columns = ['day', 'nom'] + \
3434
[x[0] + '_' + x[1] for x in agg_fwi_df.columns if x[1] != '']
3535

3636
logger.info("Finished aggregationg of FWI")
3737

3838
# Group weather dataframe by day and department and compute min, max, mean, std
3939
agg_wth_df = weather.groupby(['time', 'nom'])[
40-
'u10', 'v10', 'd2m', 't2m', 'fal', 'lai_hv', 'lai_lv', 'skt',
41-
'asn', 'snowc', 'rsn', 'sde', 'sd', 'sf', 'smlt', 'stl1', 'stl2',
42-
'stl3', 'stl4', 'slhf', 'ssr', 'str', 'sp', 'sshf', 'ssrd', 'strd', 'tsn', 'tp'
43-
].agg(['min', 'max', 'mean', 'std']).reset_index()
40+
cfg.WEATHER_ERA5T_VARS].agg(['min', 'max', 'mean', 'std']).reset_index()
4441
agg_wth_df.columns = ['day', 'nom'] + \
4542
[x[0] + '_' + x[1] for x in agg_wth_df.columns if x[1] != '']
4643

@@ -75,7 +72,7 @@ def __init__(self, era_source_path=None, viirs_source_path=None, fwi_source_path
7572
viirs_source_path (str, optional): Viirs data source path. Defaults to None.
7673
fwi_source_path (str, optional): Fwi data source path. Defaults to None.
7774
"""
78-
weather = ERA5Land(era_source_path)
75+
weather = ERA5T(era_source_path) # ERA5Land(era_source_path)
7976
nasa_firms = NASAFIRMS_VIIRS(viirs_source_path)
8077

8178
# Time span selection
@@ -100,17 +97,13 @@ def __init__(self, era_source_path=None, viirs_source_path=None, fwi_source_path
10097

10198
# Group fwi dataframe by day and department and compute min, max, mean, std
10299
agg_fwi_df = fwi_df.groupby(['day', 'departement'])[
103-
'fwi', 'ffmc', 'dmc', 'dc', 'isi', 'bui', 'dsr'
104-
].agg(['min', 'max', 'mean', 'std']).reset_index()
100+
cfg.FWI_VARS].agg(['min', 'max', 'mean', 'std']).reset_index()
105101
agg_fwi_df.columns = ['day', 'departement'] + \
106102
[x[0] + '_' + x[1] for x in agg_fwi_df.columns if x[1] != '']
107103

108104
# Group weather dataframe by day and department and compute min, max, mean, std
109105
agg_wth_df = weather.groupby(['time', 'nom'])[
110-
'u10', 'v10', 'd2m', 't2m', 'fal', 'lai_hv', 'lai_lv', 'skt',
111-
'asn', 'snowc', 'rsn', 'sde', 'sd', 'sf', 'smlt', 'stl1', 'stl2',
112-
'stl3', 'stl4', 'slhf', 'ssr', 'str', 'sp', 'sshf', 'ssrd', 'strd', 'tsn', 'tp'
113-
].agg(['min', 'max', 'mean', 'std']).reset_index()
106+
cfg.WEATHER_ERA5T_VARS].agg(['min', 'max', 'mean', 'std']).reset_index()
114107
agg_wth_df.columns = ['day', 'departement'] + \
115108
[x[0] + '_' + x[1] for x in agg_wth_df.columns if x[1] != '']
116109

pyro_risks/models/predict.py

+22-9
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import joblib
22
from urllib.request import urlopen
3+
import xgboost
34

45
from pyro_risks import config as cfg
56
from pyro_risks.datasets.fwi import get_fwi_data_for_predict
6-
from pyro_risks.datasets.ERA5 import get_data_era5land_for_predict
7+
from pyro_risks.datasets.ERA5 import get_data_era5land_for_predict, get_data_era5t_for_predict
78
from pyro_risks.datasets.era_fwi_viirs import process_dataset_to_predict
89
from pyro_risks.models.score_v0 import add_lags
910

@@ -28,10 +29,13 @@ def __init__(self, which='RF'):
2829
which (str, optional): Can be 'RF' for random forest or 'XGB' for xgboost. Defaults to 'RF'.
2930
"""
3031
if which == 'RF':
31-
self.model_path = cfg.RFMODEL_PATH
32+
self.model_path = cfg.RFMODEL_ERA5T_PATH
3233
elif which == 'XGB':
33-
self.model_path = cfg.XGBMODEL_PATH
34+
self.model_path = cfg.XGBMODEL_ERA5T_PATH
35+
else:
36+
raise ValueError("Model can be only of type RF or XGB")
3437
self.model = joblib.load(urlopen(self.model_path))
38+
self._model_type = which
3539

3640
def get_input(self, day):
3741
"""Returns for a given day data to feed into the model.
@@ -45,12 +49,15 @@ def get_input(self, day):
4549
Returns:
4650
pd.DataFrame
4751
"""
48-
model_cols = cfg.MODEL_VARIABLES
52+
model_cols = cfg.MODEL_ERA5T_VARS
4953
fwi = get_fwi_data_for_predict(day)
50-
era = get_data_era5land_for_predict(day)
54+
era = get_data_era5t_for_predict(day)
5155
res_test = process_dataset_to_predict(fwi, era)
5256
res_test = res_test.rename({'nom': 'departement'}, axis=1)
53-
res_lags = add_lags(res_test, res_test.drop(['day', 'departement'], axis=1).columns)
57+
# Add lags only for the columns on which the model was trained
58+
cols_lags = ['_'.join(x.split('_')[:-1]) for x in cfg.MODEL_ERA5T_VARS if '_lag' in x]
59+
res_lags = add_lags(res_test, cols_lags)
60+
# Select only rows corresponding to day
5461
to_predict = res_lags.loc[res_lags['day'] == day]
5562
to_predict = to_predict.drop('day', axis=1).set_index('departement')
5663
# Some NaN due to the aggregations on departments with only one line (variables with std)
@@ -68,9 +75,15 @@ def predict(self, day, country='France'):
6875
country (str, optional): Defaults to 'France'.
6976
7077
Returns:
71-
dict: keys are departements and values model probability predictions for label 1 (fire)
78+
dict: keys are departements; each value is a dict with keys 'score' and 'explainability',
79+
holding the probability prediction for label 1 (fire) and the feature contributions to the
80+
prediction respectively (explainability is currently always None)
7281
"""
7382
sample = self.get_input(day)
74-
predictions = self.model.predict_proba(sample.values)
75-
res = dict(zip(sample.index, predictions[:, 1].round(3)))
83+
if self._model_type == 'RF':
84+
predictions = self.model.predict_proba(sample.values)
85+
res = dict(zip(sample.index, predictions[:, 1].round(3)))
86+
elif self._model_type == 'XGB':
87+
predictions = self.model.predict(xgboost.DMatrix(sample))
88+
res = dict(zip(sample.index, predictions.round(3)))
7689
return {x: {'score': res[x], 'explainability': None} for x in res}

test/test_datasets.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -428,11 +428,12 @@ def test_era5t(self):
428428

429429
def test_MergedEraFwiViirs(self):
430430
ds = era_fwi_viirs.MergedEraFwiViirs(
431-
era_source_path=cfg.TEST_FR_ERA5_2019_FALLBACK,
431+
era_source_path=cfg.TEST_FR_ERA5T_FALLBACK,
432432
viirs_source_path=None,
433433
fwi_source_path=cfg.TEST_FWI_FALLBACK,
434434
)
435435
self.assertIsInstance(ds, pd.DataFrame)
436+
self.assertTrue(len(ds) > 0)
436437

437438
def test_call_era5land(self):
438439
with tempfile.TemporaryDirectory() as tmp:

test/test_models.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,13 @@ def test_xgb_model(self):
7979
def test_pyrorisk(self):
8080
pr = predict.PyroRisk(which='RF')
8181
self.assertEqual(pr.model.n_estimators, 500)
82-
self.assertEqual(pr.model_path, cfg.RFMODEL_PATH)
82+
self.assertEqual(pr.model_path, cfg.RFMODEL_ERA5T_PATH)
8383
res = pr.get_input('2020-05-05')
8484
self.assertIsInstance(res, pd.DataFrame)
85-
self.assertEqual(res.shape, (93, 41))
85+
self.assertEqual(res.shape, (93, 40))
8686
preds = pr.predict('2020-05-05')
8787
self.assertEqual(len(preds), 93)
88-
self.assertEqual(preds['Ardennes'], {'score': 0.11, 'explainability': None})
88+
self.assertEqual(preds['Ardennes'], {'score': 0.246, 'explainability': None})
8989

9090

9191
if __name__ == "__main__":

0 commit comments

Comments
 (0)