Skip to content

Commit 8b28f97

Browse files
mpolson64 authored and facebook-github-bot committed
Back Data with DataRow object (#4773)
Summary: NOTE: This is much slower than the implementation which is backed by a dataframe. For clarity, I've put this naive implementation up as its own diff, and the next diff hunts for speedups. Creates a new source of truth for Data: the DataRow. The df is now a cached property which is dynamically generated from these rows. In the future, these will become a Base object in SQLAlchemy, s.t. Data will have a SQLAlchemy relationship to a list of DataRows which live in their own table. RFC: 1. I'm renaming sem -> se here (but keeping sem in the df for now, since this could be an incredibly involved cleanup). Do we have alignment that this is a positive change? If so, I can either start or backlog the cleanup across the codebase. cc Balandat, with whom I've talked about this a while back. 2. This removes the ability for Data to contain arbitrary columns, which was added in D83682740 and, afaik, unused. Arbitrary new columns would not be compatible with the new storage setup (it was easy in the old setup, which is why we added it), and I think we should take a careful look at how to store contextual data in the future in a structured way. Differential Revision: D90605846
1 parent 0fde92e commit 8b28f97

File tree

12 files changed

+122
-48
lines changed

12 files changed

+122
-48
lines changed

ax/analysis/plotly/tests/test_marginal_effects.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def setUp(self) -> None:
4343
self.experiment.trials[i].mark_running(no_runner_required=True)
4444
self.experiment.attach_data(
4545
Data(
46-
pd.DataFrame(
46+
df=pd.DataFrame(
4747
{
4848
"trial_index": [i] * num_arms,
4949
"arm_name": [f"0_{j}" for j in range(num_arms)],

ax/core/base_trial.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from typing import Any, TYPE_CHECKING
1616

1717
from ax.core.arm import Arm
18-
from ax.core.data import Data, sort_by_trial_index_and_arm_name
18+
from ax.core.data import Data
1919
from ax.core.evaluations_to_data import raw_evaluations_to_data
2020
from ax.core.generator_run import GeneratorRun, GeneratorRunType
2121
from ax.core.metric import Metric, MetricFetchResult
@@ -442,8 +442,6 @@ def fetch_data(self, metrics: list[Metric] | None = None, **kwargs: Any) -> Data
442442
data = Metric._unwrap_trial_data_multi(
443443
results=self.fetch_data_results(metrics=metrics, **kwargs)
444444
)
445-
if not data.has_step_column:
446-
data.full_df = sort_by_trial_index_and_arm_name(data.full_df)
447445

448446
return data
449447

ax/core/data.py

Lines changed: 108 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,36 @@
3838
MAP_KEY = "step"
3939

4040

41+
class DataRow:
42+
def __init__(
43+
self,
44+
trial_index: int,
45+
arm_name: str,
46+
metric_name: str,
47+
metric_signature: str,
48+
mean: float,
49+
se: float,
50+
step: float | None = None,
51+
start_time: int | None = None,
52+
end_time: int | None = None,
53+
n: int | None = None,
54+
) -> None:
55+
self.trial_index: int = trial_index
56+
self.arm_name: str = arm_name
57+
58+
self.metric_name: str = metric_name
59+
self.metric_signature: str = metric_signature
60+
61+
self.mean: float = mean
62+
self.se: float = se
63+
64+
self.step: float | None = step
65+
66+
self.start_time: int | None = start_time
67+
self.end_time: int | None = end_time
68+
self.n: int | None = n
69+
70+
4171
class Data(Base, SerializationMixin):
4272
"""Class storing numerical data for an experiment.
4373
@@ -101,8 +131,6 @@ class Data(Base, SerializationMixin):
101131
"start_time": pd.Timestamp,
102132
"end_time": pd.Timestamp,
103133
"n": int,
104-
"frac_nonnull": np.float64,
105-
"random_split": int,
106134
MAP_KEY: float,
107135
}
108136

@@ -115,16 +143,19 @@ class Data(Base, SerializationMixin):
115143
"metric_signature",
116144
]
117145

118-
full_df: pd.DataFrame
146+
_data_rows: list[DataRow]
119147

120148
def __init__(
121149
self,
150+
data_rows: Iterable[DataRow] | None = None,
122151
df: pd.DataFrame | None = None,
123152
_skip_ordering_and_validation: bool = False,
124153
) -> None:
125154
"""Initialize a ``Data`` object from the given DataFrame.
126155
127156
Args:
157+
data_rows: Iterable of DataRows. If provided, this will be used as the
158+
source of truth for Data, over df.
128159
df: DataFrame with underlying data, and required columns. Data must
129160
be unique at the level of ("trial_index", "arm_name",
130161
"metric_name"), plus "step" if a "step" column is present. A
@@ -135,32 +166,86 @@ def __init__(
135166
Intended only for use in `Data.filter`, where the contents
136167
of the DataFrame are known to be ordered and valid.
137168
"""
138-
if df is None:
139-
# Initialize with barebones DF with expected dtypes
140-
self.full_df = pd.DataFrame.from_dict(
169+
if data_rows is not None:
170+
if isinstance(data_rows, pd.DataFrame):
171+
raise ValueError(
172+
"data_rows must be an iterable of DataRows, not a DataFrame."
173+
)
174+
self._data_rows = [*data_rows]
175+
elif df is not None:
176+
# Unroll the df into a list of DataRows
177+
if missing_columns := self.REQUIRED_COLUMNS - {*df.columns}:
178+
raise ValueError(
179+
f"Dataframe must contain required columns {list(missing_columns)}."
180+
)
181+
182+
self._data_rows = [
183+
DataRow(
184+
trial_index=row["trial_index"],
185+
arm_name=row["arm_name"],
186+
metric_name=row["metric_name"],
187+
metric_signature=row["metric_signature"],
188+
mean=row["mean"],
189+
se=row["sem"],
190+
step=row.get(MAP_KEY),
191+
start_time=row.get("start_time"),
192+
end_time=row.get("end_time"),
193+
n=row.get("n"),
194+
)
195+
for _, row in df.iterrows()
196+
]
197+
else:
198+
self._data_rows = []
199+
200+
self._memo_df: pd.DataFrame | None = None
201+
self.has_step_column: bool = any(
202+
row.step is not None for row in self._data_rows
203+
)
204+
205+
@cached_property
206+
def full_df(self) -> pd.DataFrame:
207+
"""
208+
Convert the DataRows into a pandas DataFrame. If step, start_time, or end_time
209+
is None for all rows the column will be elided.
210+
"""
211+
if len(self._data_rows) == 0:
212+
return pd.DataFrame.from_dict(
141213
{
142214
col: pd.Series([], dtype=self.COLUMN_DATA_TYPES[col])
143215
for col in self.REQUIRED_COLUMNS
144216
}
145217
)
146-
elif _skip_ordering_and_validation:
147-
self.full_df = df
148-
else:
149-
columns = set(df.columns)
150-
missing_columns = self.REQUIRED_COLUMNS - columns
151-
if missing_columns:
152-
raise ValueError(
153-
f"Dataframe must contain required columns {list(missing_columns)}."
154-
)
155-
# Drop rows where every input is null. Since `dropna` can be slow, first
156-
# check trial index to see if dropping nulls might be needed.
157-
if df["trial_index"].isnull().any():
158-
df = df.dropna(axis=0, how="all", ignore_index=True)
159-
df = self._safecast_df(df=df)
160-
self.full_df = self._get_df_with_cols_in_expected_order(df=df)
161218

162-
self._memo_df: pd.DataFrame | None = None
163-
self.has_step_column: bool = MAP_KEY in self.full_df.columns
219+
# Detect whether any of the optional attributes are present and should be
220+
# included as columns in the full DataFrame.
221+
include_step = any(row.step is not None for row in self._data_rows)
222+
include_start_time = any(row.start_time is not None for row in self._data_rows)
223+
include_end_time = any(row.end_time is not None for row in self._data_rows)
224+
include_n = any(row.n is not None for row in self._data_rows)
225+
226+
records = [
227+
{
228+
"trial_index": row.trial_index,
229+
"arm_name": row.arm_name,
230+
"metric_name": row.metric_name,
231+
"metric_signature": row.metric_signature,
232+
"mean": row.mean,
233+
"sem": row.se,
234+
**({"step": row.step} if include_step else {}),
235+
**({"start_time": row.start_time} if include_start_time else {}),
236+
**({"end_time": row.end_time} if include_end_time else {}),
237+
**({"n": row.n} if include_n else {}),
238+
}
239+
for row in self._data_rows
240+
]
241+
242+
return sort_by_trial_index_and_arm_name(
243+
df=self._get_df_with_cols_in_expected_order(
244+
df=self._safecast_df(
245+
df=pd.DataFrame.from_records(records),
246+
),
247+
)
248+
)
164249

165250
@classmethod
166251
def _get_df_with_cols_in_expected_order(cls, df: pd.DataFrame) -> pd.DataFrame:

ax/core/tests/test_data.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -187,14 +187,6 @@ def test_from_multiple_with_generator(self) -> None:
187187
self.assertEqual(len(data.full_df), 2 * len(self.data_with_df.full_df))
188188
self.assertFalse(data.has_step_column)
189189

190-
def test_extra_columns(self) -> None:
191-
value = 3
192-
extra_col_df = self.df.assign(foo=value)
193-
data = Data(df=extra_col_df)
194-
self.assertIn("foo", data.full_df.columns)
195-
self.assertIn("foo", data.df.columns)
196-
self.assertTrue((data.full_df["foo"] == value).all())
197-
198190
def test_get_df_with_cols_in_expected_order(self) -> None:
199191
with self.subTest("Wrong order"):
200192
df = pd.DataFrame(columns=["mean", "trial_index", "hat"], data=[[0] * 3])

ax/core/tests/test_experiment.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -673,7 +673,7 @@ def test_fetch_and_store_data(self) -> None:
673673

674674
# Verify we do get the stored data if there are an unimplemented metrics.
675675
# Remove attached data for nonexistent metric.
676-
exp.data.full_df = exp.data.full_df.loc[lambda x: x["metric_name"] != "z"]
676+
exp.data = Data(df=exp.data.full_df.loc[lambda x: x["metric_name"] != "z"])
677677

678678
# Remove implemented metric that is `available_while_running`
679679
# (and therefore not pulled from cache).
@@ -685,7 +685,9 @@ def test_fetch_and_store_data(self) -> None:
685685
looked_up_df = looked_up_data.full_df
686686
self.assertFalse((looked_up_df["metric_name"] == "z").any())
687687
self.assertTrue(
688-
batch.fetch_data().full_df.equals(
688+
batch.fetch_data()
689+
.full_df.sort_values(["arm_name", "metric_name"], ignore_index=True)
690+
.equals(
689691
looked_up_df.loc[lambda x: (x["trial_index"] == 0)].sort_values(
690692
["arm_name", "metric_name"], ignore_index=True
691693
)

ax/plot/pareto_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def get_observed_pareto_frontiers(
207207
):
208208
# Make sure status quo is always included, for derelativization
209209
arm_names.append(experiment.status_quo.name)
210-
data = Data(data.df[data.df["arm_name"].isin(arm_names)])
210+
data = Data(df=data.df[data.df["arm_name"].isin(arm_names)])
211211
adapter = get_tensor_converter_adapter(experiment=experiment, data=data)
212212
pareto_observations = observed_pareto_frontier(adapter=adapter)
213213
# Convert to ParetoFrontierResults

ax/plot/scatter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1731,7 +1731,7 @@ def tile_observations(
17311731
if data is None:
17321732
data = experiment.fetch_data()
17331733
if arm_names is not None:
1734-
data = Data(data.df[data.df["arm_name"].isin(arm_names)])
1734+
data = Data(df=data.df[data.df["arm_name"].isin(arm_names)])
17351735
m_ts = Generators.THOMPSON(
17361736
data=data,
17371737
search_space=experiment.search_space,

ax/plot/tests/test_fitted_scatter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_fitted_scatter(self) -> None:
3333
model = Generators.BOTORCH_MODULAR(
3434
# Adapter kwargs
3535
experiment=exp,
36-
data=Data.from_multiple_data([data, Data(df)]),
36+
data=Data.from_multiple_data([data, Data(df=df)]),
3737
)
3838
# Assert that each type of plot can be constructed successfully
3939
scalarized_metric_config = [

ax/plot/tests/test_pareto_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def test_get_observed_pareto_frontiers(self) -> None:
107107
# For the check below, compute which arms are better than SQ
108108
df = experiment.fetch_data().df
109109
df["sem"] = np.nan
110-
data = Data(df)
110+
data = Data(df=df)
111111
sq_val = df[(df["arm_name"] == "status_quo") & (df["metric_name"] == "m1")][
112112
"mean"
113113
].values[0]

ax/storage/json_store/tests/test_json_store.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -705,10 +705,6 @@ def test_decode_map_data_backward_compatible(self) -> None:
705705
class_decoder_registry=CORE_CLASS_DECODER_REGISTRY,
706706
)
707707
self.assertEqual(len(map_data.full_df), 2)
708-
# Even though the "epoch" and "timestamps" columns have not been
709-
# renamed to "step", they are present
710-
self.assertEqual(map_data.full_df["epoch"].tolist(), [0.0, 1.0])
711-
self.assertEqual(map_data.full_df["timestamps"].tolist(), [3.0, 4.0])
712708
self.assertIsInstance(map_data, Data)
713709

714710
with self.subTest("Single map key"):
@@ -729,8 +725,8 @@ def test_decode_map_data_backward_compatible(self) -> None:
729725
decoder_registry=CORE_DECODER_REGISTRY,
730726
class_decoder_registry=CORE_CLASS_DECODER_REGISTRY,
731727
)
732-
self.assertIn("epoch", map_data.full_df.columns)
733-
self.assertEqual(map_data.full_df["epoch"].tolist(), [0.0, 1.0])
728+
self.assertEqual(len(map_data.full_df), 2)
729+
self.assertIsInstance(map_data, Data)
734730

735731
with self.subTest("No map key"):
736732
data_json = {

0 commit comments

Comments
 (0)