
Commit 3e065c0

pull in 13.9rc1 changes
1 parent 51dd0dd commit 3e065c0

6 files changed: +149 -60 lines changed

ads/opctl/operator/lowcode/common/utils.py

Lines changed: 91 additions & 0 deletions
@@ -11,6 +11,7 @@
 import tempfile
 from typing import List, Union

+import cloudpickle
 import fsspec
 import oracledb
 import pandas as pd
@@ -126,7 +127,26 @@ def load_data(data_spec, storage_options=None, **kwargs):
     return data


+def _safe_write(fn, **kwargs):
+    try:
+        fn(**kwargs)
+    except Exception:
+        logger.warning(f'Failed to write file {kwargs.get("filename", "UNKNOWN")}')
+
+
 def write_data(data, filename, format, storage_options=None, index=False, **kwargs):
+    return _safe_write(
+        fn=_write_data,
+        data=data,
+        filename=filename,
+        format=format,
+        storage_options=storage_options,
+        index=index,
+        **kwargs,
+    )
+
+
+def _write_data(data, filename, format, storage_options=None, index=False, **kwargs):
     disable_print()
     if not format:
         _, format = os.path.splitext(filename)
@@ -143,11 +163,24 @@ def write_data(data, filename, format, storage_options=None, index=False, **kwargs):


 def write_json(json_dict, filename, storage_options=None):
+    return _safe_write(
+        fn=_write_json,
+        json_dict=json_dict,
+        filename=filename,
+        storage_options=storage_options,
+    )
+
+
+def _write_json(json_dict, filename, storage_options=None):
     with fsspec.open(filename, mode="w", **storage_options) as f:
         f.write(json.dumps(json_dict))


 def write_simple_json(data, path):
+    return _safe_write(fn=_write_simple_json, data=data, path=path)
+
+
+def _write_simple_json(data, path):
     if ObjectStorageDetails.is_oci_path(path):
         storage_options = default_signer()
     else:
@@ -156,6 +189,60 @@ def write_simple_json(data, path):
         json.dump(data, f, indent=4)


+def write_file(local_filename, remote_filename, storage_options, **kwargs):
+    return _safe_write(
+        fn=_write_file,
+        local_filename=local_filename,
+        remote_filename=remote_filename,
+        storage_options=storage_options,
+        **kwargs,
+    )
+
+
+def _write_file(local_filename, remote_filename, storage_options, **kwargs):
+    with open(local_filename) as f1:
+        with fsspec.open(
+            remote_filename,
+            "w",
+            **storage_options,
+        ) as f2:
+            f2.write(f1.read())
+
+
+def load_pkl(filepath):
+    return _safe_write(fn=_load_pkl, filepath=filepath)
+
+
+def _load_pkl(filepath):
+    storage_options = {}
+    if ObjectStorageDetails.is_oci_path(filepath):
+        storage_options = default_signer()
+
+    with fsspec.open(filepath, "rb", **storage_options) as f:
+        return cloudpickle.load(f)
+    return None
+
+
+def write_pkl(obj, filename, output_dir, storage_options):
+    return _safe_write(
+        fn=_write_pkl,
+        obj=obj,
+        filename=filename,
+        output_dir=output_dir,
+        storage_options=storage_options,
+    )
+
+
+def _write_pkl(obj, filename, output_dir, storage_options):
+    pkl_path = os.path.join(output_dir, filename)
+    with fsspec.open(
+        pkl_path,
+        "wb",
+        **storage_options,
+    ) as f:
+        cloudpickle.dump(obj, f)
+
+
 def merge_category_columns(data, target_category_columns):
     result = data.apply(
         lambda x: "__".join([str(x[col]) for col in target_category_columns]), axis=1
@@ -290,4 +377,8 @@ def disable_print():

 # Restore
 def enable_print():
+    try:
+        sys.stdout.close()
+    except Exception:
+        pass
     sys.stdout = sys.__stdout__
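Net effect of this file's changes: every public write helper (write_data, write_json, write_simple_json, write_file, write_pkl) plus load_pkl is now a thin wrapper that routes through _safe_write, so an I/O failure is logged as a warning instead of raising; note that _safe_write does not return the wrapped function's result. A minimal caller-side sketch, with a hypothetical DataFrame and output path that are not part of the commit:

import pandas as pd

from ads.opctl.operator.lowcode.common.utils import write_data

df = pd.DataFrame({"Date": ["2025-01-01"], "Sales": [100.0]})

# Before this change a failing write raised; now the exception is swallowed by
# _safe_write and a warning such as "Failed to write file /tmp/forecast.csv"
# is logged, so the surrounding report generation can continue.
write_data(data=df, filename="/tmp/forecast.csv", format="csv")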

ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 14 additions & 3 deletions
@@ -38,6 +38,7 @@ def __init__(self, config: ForecastOperatorConfig, datasets: ForecastDatasets):
         super().__init__(config, datasets)
         self.global_explanation = {}
         self.local_explanation = {}
+        self.explainability_kwargs = {}

     def set_kwargs(self):
         model_kwargs_cleaned = self.spec.model_kwargs
@@ -54,6 +55,9 @@ def set_kwargs(self):
             self.spec.preprocessing.enabled
             or model_kwargs_cleaned.get("preprocessing", True)
         )
+        sample_ratio = model_kwargs_cleaned.pop("sample_to_feature_ratio", None)
+        if sample_ratio is not None:
+            self.explainability_kwargs = {"sample_to_feature_ratio": sample_ratio}
         return model_kwargs_cleaned, time_budget

     def preprocess(self, data, series_id):  # TODO: re-use self.le for explanations
@@ -445,6 +449,7 @@ def explain_model(self):
                 else None,
                 pd.DataFrame(data_i[self.spec.target_column]),
                 task="forecasting",
+                **self.explainability_kwargs,
             )

             # Generate explanations for the forecast
@@ -518,7 +523,9 @@ def get_validation_score_and_metric(self, model):
         model_params = model.selected_model_params_
         if len(trials) > 0:
             score_col = [col for col in trials.columns if "Score" in col][0]
-            validation_score = trials[trials.Hyperparameters == model_params][score_col].iloc[0]
+            validation_score = trials[trials.Hyperparameters == model_params][
+                score_col
+            ].iloc[0]
         else:
             validation_score = 0
         return -1 * validation_score
@@ -531,8 +538,12 @@ def generate_train_metrics(self) -> pd.DataFrame:
         for s_id in self.forecast_output.list_series_ids():
             try:
                 metrics = {self.spec.metric.upper(): self.models[s_id]["score"]}
-                metrics_df = pd.DataFrame.from_dict(metrics, orient="index", columns=[s_id])
-                logger.warning("AutoMLX failed to generate training metrics. Recovering validation loss instead")
+                metrics_df = pd.DataFrame.from_dict(
+                    metrics, orient="index", columns=[s_id]
+                )
+                logger.warning(
+                    "AutoMLX failed to generate training metrics. Recovering validation loss instead"
+                )
                 total_metrics = pd.concat([total_metrics, metrics_df], axis=1)
             except Exception as e:
                 logger.debug(
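The new explainability_kwargs plumbing lets a caller pass sample_to_feature_ratio through the operator's model_kwargs: set_kwargs pops it so it never reaches the AutoMLX trainer, and explain_model forwards it to the explainer call. A standalone sketch of that flow with hypothetical values (the time_budget entry is only an example of an unrelated kwarg left untouched):

model_kwargs_cleaned = {"time_budget": 60, "sample_to_feature_ratio": 5}

explainability_kwargs = {}
sample_ratio = model_kwargs_cleaned.pop("sample_to_feature_ratio", None)
if sample_ratio is not None:
    explainability_kwargs = {"sample_to_feature_ratio": sample_ratio}

# model_kwargs_cleaned  -> {"time_budget": 60}             (passed to model fitting)
# explainability_kwargs -> {"sample_to_feature_ratio": 5}  (unpacked into the explainer
#                           call via **self.explainability_kwargs in explain_model)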

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 14 additions & 12 deletions
@@ -11,7 +11,6 @@
 from abc import ABC, abstractmethod
 from typing import Tuple

-import fsspec
 import numpy as np
 import pandas as pd
 import report_creator as rc
@@ -25,10 +24,13 @@
     disable_print,
     enable_print,
     human_time_friendly,
+    load_pkl,
     merged_category_column_name,
     seconds_to_datetime,
     write_data,
+    write_file,
     write_json,
+    write_pkl,
 )
 from ads.opctl.operator.lowcode.forecast.utils import (
     _build_metrics_df,
@@ -38,8 +40,6 @@
     evaluate_train_metrics,
     get_auto_select_plot,
     get_forecast_plots,
-    load_pkl,
-    write_pkl,
 )

 from ..const import (
@@ -493,13 +493,11 @@ def _save_report(
         enable_print()

         report_path = os.path.join(unique_output_dir, self.spec.report_filename)
-        with open(report_local_path) as f1:
-            with fsspec.open(
-                report_path,
-                "w",
-                **storage_options,
-            ) as f2:
-                f2.write(f1.read())
+        write_file(
+            local_filename=report_local_path,
+            remote_filename=report_path,
+            storage_options=storage_options,
+        )

         # forecast csv report
         # todo: add test data into forecast.csv
@@ -576,7 +574,9 @@ def _save_report(
             # Round to 4 decimal places before writing
             global_expl_rounded = self.formatted_global_explanation.copy()
             global_expl_rounded = global_expl_rounded.apply(
-                lambda col: np.round(col, 4) if np.issubdtype(col.dtype, np.number) else col
+                lambda col: np.round(col, 4)
+                if np.issubdtype(col.dtype, np.number)
+                else col
             )
             if self.spec.generate_explanation_files:
                 write_data(
@@ -598,7 +598,9 @@ def _save_report(
             # Round to 4 decimal places before writing
             local_expl_rounded = self.formatted_local_explanation.copy()
             local_expl_rounded = local_expl_rounded.apply(
-                lambda col: np.round(col, 4) if np.issubdtype(col.dtype, np.number) else col
+                lambda col: np.round(col, 4)
+                if np.issubdtype(col.dtype, np.number)
+                else col
            )
             if self.spec.generate_explanation_files:
                 write_data(
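Functionally, the nested open/fsspec.open block that copied the rendered report to its destination is replaced by the shared write_file helper from common/utils.py, which also inherits the _safe_write logging behavior. A minimal sketch of the equivalent call, with hypothetical paths and assuming the standard ads.common.auth.default_signer helper for OCI authentication:

from ads.common.auth import default_signer
from ads.opctl.operator.lowcode.common.utils import write_file

# Copies a local file to a (possibly remote) destination via fsspec; for an OCI
# Object Storage URI the storage_options carry the signer configuration.
write_file(
    local_filename="/tmp/report.html",
    remote_filename="oci://bucket@namespace/forecast/report.html",
    storage_options=default_signer(),
)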

ads/opctl/operator/lowcode/forecast/model/neuralprophet.py

Lines changed: 10 additions & 14 deletions
@@ -19,12 +19,10 @@
 from ads.opctl.operator.lowcode.common.utils import (
     disable_print,
     enable_print,
-)
-from ads.opctl.operator.lowcode.forecast.utils import (
-    _select_plot_list,
     load_pkl,
     write_pkl,
 )
+from ads.opctl.operator.lowcode.forecast.utils import _select_plot_list

 from ..const import DEFAULT_TRIALS, SupportedModels
 from ..operator_config import ForecastOperatorConfig
@@ -159,20 +157,18 @@ def _train_model(self, i, s_id, df, model_kwargs):
             upper_bound=self.get_horizon(forecast[upper_bound_col_name]).values,
             lower_bound=self.get_horizon(forecast[lower_bound_col_name]).values,
         )
-        core_columns = set(forecast.columns) - set(
-            [
-                "y",
-                "yhat1",
-                upper_bound_col_name,
-                lower_bound_col_name,
-                "future_regressors_additive",
-                "future_regressors_multiplicative",
-            ]
-        )
+        core_columns = set(forecast.columns) - {
+            "y",
+            "yhat1",
+            upper_bound_col_name,
+            lower_bound_col_name,
+            "future_regressors_additive",
+            "future_regressors_multiplicative",
+        }
         exog_variables = set(
             filter(lambda x: x.startswith("future_regressor_"), list(core_columns))
         )
-        combine_terms = list(core_columns - exog_variables - set(["ds"]))
+        combine_terms = list(core_columns - exog_variables - {"ds"})
         temp_df = (
             forecast[list(core_columns)]
             .rename({"ds": "Date"}, axis=1)
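The column-set edits here are behavior-preserving cleanups: set([...]) becomes a set literal. A quick standalone check (not from the commit) that the literal form drops the same columns:

cols = ["ds", "y", "yhat1", "trend"]
assert set(cols) - set(["y", "yhat1"]) == set(cols) - {"y", "yhat1"}  # same result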

ads/opctl/operator/lowcode/forecast/utils.py

Lines changed: 13 additions & 29 deletions
@@ -1,14 +1,12 @@
 #!/usr/bin/env python

-# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
+# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

 import logging
 import os
 from typing import Set

-import cloudpickle
-import fsspec
 import numpy as np
 import pandas as pd
 import report_creator as rc
@@ -21,7 +19,6 @@
     r2_score,
 )

-from ads.common.object_storage_details import ObjectStorageDetails
 from ads.dataset.label_encoder import DataFrameLabelEncoder
 from ads.opctl import logger
 from ads.opctl.operator.lowcode.forecast.const import ForecastOutputColumns
@@ -170,26 +167,6 @@ def _build_metrics_per_horizon(
     return metrics_df


-def load_pkl(filepath):
-    storage_options = {}
-    if ObjectStorageDetails.is_oci_path(filepath):
-        storage_options = default_signer()
-
-    with fsspec.open(filepath, "rb", **storage_options) as f:
-        return cloudpickle.load(f)
-    return None
-
-
-def write_pkl(obj, filename, output_dir, storage_options):
-    pkl_path = os.path.join(output_dir, filename)
-    with fsspec.open(
-        pkl_path,
-        "wb",
-        **storage_options,
-    ) as f:
-        cloudpickle.dump(obj, f)
-
-
 def _build_metrics_df(y_true, y_pred, series_id):
     if len(y_true) == 0 or len(y_pred) == 0:
         return pd.DataFrame()
@@ -251,7 +228,10 @@ def evaluate_train_metrics(output):


 def _select_plot_list(fn, series_ids, target_category_column):
-    blocks = [rc.Widget(fn(s_id=s_id), label=s_id if target_category_column else None) for s_id in series_ids]
+    blocks = [
+        rc.Widget(fn(s_id=s_id), label=s_id if target_category_column else None)
+        for s_id in series_ids
+    ]
     return rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0]


@@ -264,8 +244,10 @@ def get_auto_select_plot(backtest_results):
     back_test_csv_columns = backtest_results.columns.tolist()
     back_test_column = "backtest"
     metric_column = "metric"
-    models = [x for x in back_test_csv_columns if x not in [back_test_column, metric_column]]
-    for i, column in enumerate(models):
+    models = [
+        x for x in back_test_csv_columns if x not in [back_test_column, metric_column]
+    ]
+    for column in models:
         fig.add_trace(
             go.Scatter(
                 x=backtest_results[back_test_column],
@@ -283,7 +265,7 @@ def get_forecast_plots(
     horizon,
     test_data=None,
     ci_interval_width=0.95,
-    target_category_column=None
+    target_category_column=None,
 ):
     def plot_forecast_plotly(s_id):
         fig = go.Figure()
@@ -380,7 +362,9 @@ def plot_forecast_plotly(s_id):
         )
         return fig

-    return _select_plot_list(plot_forecast_plotly, forecast_output.list_series_ids(), target_category_column)
+    return _select_plot_list(
+        plot_forecast_plotly, forecast_output.list_series_ids(), target_category_column
+    )


 def convert_target(target: str, target_col: str):
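With load_pkl and write_pkl (and their cloudpickle, fsspec, and ObjectStorageDetails imports) deleted from this module, the pickle helpers now live only in common/utils.py; callers update their imports accordingly, as base_model.py and neuralprophet.py do in this commit:

from ads.opctl.operator.lowcode.common.utils import load_pkl, write_pkl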
