Commit 2616bcc
mpolson64 authored and facebook-github-bot committed
Back Data with DataRow object (facebook#4773)
Summary:

NOTE: This is much slower than the implementation that is backed by a dataframe. For clarity, I've put this naive implementation up as its own diff; the next diff hunts for speedups.

Creates a new source of truth for Data: the DataRow. The df is now a cached property which is dynamically generated from these rows. In the future, these will become a Base object in SQLAlchemy s.t. Data will have a SQLAlchemy relationship to a list of DataRows which live in their own table.

RFC:
1. I'm renaming sem -> se here (but keeping sem in the df for now, since renaming it there could be an incredibly involved cleanup). Do we have alignment that this is a positive change? If so, I can either start or backlog the cleanup across the codebase. cc Balandat, who I've talked about this with a while back.
2. This removes the ability for Data to contain arbitrary columns, which was added in D83682740 and is, afaik, unused. Arbitrary new columns would not be compatible with the new storage setup (it was easy in the old setup, which is why we added it), and I think we should take a careful look at how to store contextual data in a structured way in the future.

Differential Revision: D90605846
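To make the two construction paths concrete, here is a minimal sketch; the metric names and values are illustrative, and the df column set follows the REQUIRED_COLUMNS visible in the ax/core/data.py diff below:

import pandas as pd

from ax.core.data import Data, DataRow

# Path 1: construct directly from DataRows, the new source of truth.
rows = [
    DataRow(
        trial_index=0,
        arm_name="0_0",
        metric_name="accuracy",
        metric_signature="accuracy",
        mean=0.91,
        se=0.02,  # named se on the row object (renamed from sem)
    )
]
data = Data(data_rows=rows)

# Path 2: construct from a DataFrame, which __init__ unrolls into DataRows.
df = pd.DataFrame(
    {
        "trial_index": [0],
        "arm_name": ["0_0"],
        "metric_name": ["accuracy"],
        "metric_signature": ["accuracy"],
        "mean": [0.91],
        "sem": [0.02],  # the df keeps the sem name for now, per RFC point 1
    }
)
data = Data(df=df)

# full_df is now a cached property, generated lazily from the rows.
print(data.full_df)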
1 parent 27a59b8 commit 2616bcc

12 files changed: +135 −46 lines


ax/analysis/plotly/tests/test_marginal_effects.py

Lines changed: 1 addition & 1 deletion

@@ -43,7 +43,7 @@ def setUp(self) -> None:
             self.experiment.trials[i].mark_running(no_runner_required=True)
             self.experiment.attach_data(
                 Data(
-                    pd.DataFrame(
+                    df=pd.DataFrame(
                         {
                             "trial_index": [i] * num_arms,
                             "arm_name": [f"0_{j}" for j in range(num_arms)],

ax/core/base_trial.py

Lines changed: 1 addition & 3 deletions

@@ -15,7 +15,7 @@
 from typing import Any, TYPE_CHECKING

 from ax.core.arm import Arm
-from ax.core.data import Data, sort_by_trial_index_and_arm_name
+from ax.core.data import Data
 from ax.core.evaluations_to_data import raw_evaluations_to_data
 from ax.core.generator_run import GeneratorRun, GeneratorRunType
 from ax.core.metric import Metric, MetricFetchResult
@@ -442,8 +442,6 @@ def fetch_data(self, metrics: list[Metric] | None = None, **kwargs: Any) -> Data
         data = Metric._unwrap_trial_data_multi(
             results=self.fetch_data_results(metrics=metrics, **kwargs)
         )
-        if not data.has_step_column:
-            data.full_df = sort_by_trial_index_and_arm_name(data.full_df)

         return data
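A side effect visible in this hunk, as I read the diff: fetch_data no longer sorts its result, so row order now follows the underlying DataRows, and comparisons that assumed a canonical order must sort explicitly (the test_experiment.py change below does exactly this). A self-contained pandas illustration of why:

import pandas as pd

# Two frames with identical contents in different row orders compare
# unequal until both are sorted into a canonical order.
a = pd.DataFrame({"arm_name": ["0_1", "0_0"], "mean": [2.0, 1.0]})
b = pd.DataFrame({"arm_name": ["0_0", "0_1"], "mean": [1.0, 2.0]})

assert not a.equals(b)
assert a.sort_values(["arm_name"], ignore_index=True).equals(
    b.sort_values(["arm_name"], ignore_index=True)
)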

ax/core/data.py

Lines changed: 121 additions & 21 deletions

@@ -38,6 +38,51 @@
 MAP_KEY = "step"


+class DataRow:
+    trial_index: int
+    arm_name: str
+
+    metric_name: str
+    metric_signature: str
+
+    mean: float
+    se: float
+
+    step: float | None
+
+    start_time: int | None
+    end_time: int | None
+    n: int | None
+
+    def __init__(
+        self,
+        trial_index: int,
+        arm_name: str,
+        metric_name: str,
+        metric_signature: str,
+        mean: float,
+        se: float,
+        step: float | None = None,
+        start_time: int | None = None,
+        end_time: int | None = None,
+        n: int | None = None,
+    ) -> None:
+        self.trial_index = trial_index
+        self.arm_name = arm_name
+
+        self.metric_name = metric_name
+        self.metric_signature = metric_signature
+
+        self.mean = mean
+        self.se = se
+
+        self.step = step
+
+        self.start_time = start_time
+        self.end_time = end_time
+        self.n = n
+
+
 class Data(Base, SerializationMixin):
     """Class storing numerical data for an experiment.

@@ -115,16 +160,19 @@ class Data(Base, SerializationMixin):
         "metric_signature",
     ]

-    full_df: pd.DataFrame
+    _data_rows: list[DataRow]

     def __init__(
         self,
+        data_rows: Iterable[DataRow] | None = None,
         df: pd.DataFrame | None = None,
         _skip_ordering_and_validation: bool = False,
     ) -> None:
         """Initialize a ``Data`` object from the given DataFrame.

         Args:
+            data_rows: Iterable of DataRows. If provided, this will be used as the
+                source of truth for Data, over df.
             df: DataFrame with underlying data, and required columns. Data must
                 be unique at the level of ("trial_index", "arm_name",
                 "metric_name"), plus "step" if a "step" column is present. A
@@ -135,32 +183,84 @@ def __init__(
                 Intended only for use in `Data.filter`, where the contents
                 of the DataFrame are known to be ordered and valid.
         """
-        if df is None:
-            # Initialize with barebones DF with expected dtypes
-            self.full_df = pd.DataFrame.from_dict(
+        if data_rows is not None:
+            if isinstance(data_rows, pd.DataFrame):
+                raise ValueError(
+                    "data_rows must be an iterable of DataRows, not a DataFrame."
+                )
+            self._data_rows = [*data_rows]
+        elif df is not None:
+            # Unroll the df into a list of DataRows
+            if missing_columns := self.REQUIRED_COLUMNS - {*df.columns}:
+                raise ValueError(
+                    f"Dataframe must contain required columns {list(missing_columns)}."
+                )
+
+            self._data_rows = [
+                DataRow(
+                    trial_index=row["trial_index"],
+                    arm_name=row["arm_name"],
+                    metric_name=row["metric_name"],
+                    metric_signature=row["metric_signature"],
+                    mean=row["mean"],
+                    se=row["sem"],
+                    step=row.get(MAP_KEY),
+                    start_time=row.get("start_time"),
+                    end_time=row.get("end_time"),
+                    n=row.get("n"),
+                )
+                for _, row in df.iterrows()
+            ]
+        else:
+            self._data_rows = []
+
+        self._memo_df: pd.DataFrame | None = None
+        self.has_step_column: bool = any(
+            row.step is not None for row in self._data_rows
+        )
+
+    @cached_property
+    def full_df(self) -> pd.DataFrame:
+        """
+        Convert the DataRows into a pandas DataFrame. If step, start_time, or end_time
+        is None for all rows the column will be elided.
+        """
+        if len(self._data_rows) == 0:
+            return pd.DataFrame.from_dict(
                 {
                     col: pd.Series([], dtype=self.COLUMN_DATA_TYPES[col])
                     for col in self.REQUIRED_COLUMNS
                 }
             )
-        elif _skip_ordering_and_validation:
-            self.full_df = df
-        else:
-            columns = set(df.columns)
-            missing_columns = self.REQUIRED_COLUMNS - columns
-            if missing_columns:
-                raise ValueError(
-                    f"Dataframe must contain required columns {list(missing_columns)}."
-                )
-            # Drop rows where every input is null. Since `dropna` can be slow, first
-            # check trial index to see if dropping nulls might be needed.
-            if df["trial_index"].isnull().any():
-                df = df.dropna(axis=0, how="all", ignore_index=True)
-            df = self._safecast_df(df=df)
-            self.full_df = self._get_df_with_cols_in_expected_order(df=df)

-        self._memo_df: pd.DataFrame | None = None
-        self.has_step_column: bool = MAP_KEY in self.full_df.columns
+        # Detect whether any of the optional attributes are present and should be
+        # included as columns in the full DataFrame.
+        include_step = any(row.step is not None for row in self._data_rows)
+        include_start_time = any(row.start_time is not None for row in self._data_rows)
+        include_end_time = any(row.end_time is not None for row in self._data_rows)
+        include_n = any(row.n is not None for row in self._data_rows)
+
+        records = [
+            {
+                "trial_index": row.trial_index,
+                "arm_name": row.arm_name,
+                "metric_name": row.metric_name,
+                "metric_signature": row.metric_signature,
+                "mean": row.mean,
+                "sem": row.se,
+                **({"step": row.step} if include_step else {}),
+                **({"start_time": row.start_time} if include_start_time else {}),
+                **({"end_time": row.end_time} if include_end_time else {}),
+                **({"n": row.n} if include_n else {}),
+            }
+            for row in self._data_rows
+        ]
+
+        return self._get_df_with_cols_in_expected_order(
+            df=self._safecast_df(
+                df=pd.DataFrame.from_records(records),
+            ),
+        )

     @classmethod
     def _get_df_with_cols_in_expected_order(cls, df: pd.DataFrame) -> pd.DataFrame:
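A short usage sketch of the column-eliding behavior described in the full_df docstring (values illustrative; it assumes _safecast_df and _get_df_with_cols_in_expected_order preserve the optional step column, which the diff implies):

from ax.core.data import Data, DataRow

base = dict(
    trial_index=0,
    arm_name="0_0",
    metric_name="loss",
    metric_signature="loss",
    mean=1.5,
    se=0.1,
)

no_step = Data(data_rows=[DataRow(**base)])
with_step = Data(data_rows=[DataRow(**base, step=10.0)])

# "step" is elided when it is None for every row, included otherwise.
assert "step" not in no_step.full_df.columns
assert "step" in with_step.full_df.columns
assert not no_step.has_step_column
assert with_step.has_step_column

Because full_df is a functools.cached_property, it is computed once per instance on first access; mutating _data_rows afterwards would not refresh the cached frame.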

ax/core/tests/test_data.py

Lines changed: 0 additions & 8 deletions

@@ -187,14 +187,6 @@ def test_from_multiple_with_generator(self) -> None:
         self.assertEqual(len(data.full_df), 2 * len(self.data_with_df.full_df))
         self.assertFalse(data.has_step_column)

-    def test_extra_columns(self) -> None:
-        value = 3
-        extra_col_df = self.df.assign(foo=value)
-        data = Data(df=extra_col_df)
-        self.assertIn("foo", data.full_df.columns)
-        self.assertIn("foo", data.df.columns)
-        self.assertTrue((data.full_df["foo"] == value).all())
-
     def test_get_df_with_cols_in_expected_order(self) -> None:
         with self.subTest("Wrong order"):
             df = pd.DataFrame(columns=["mean", "trial_index", "hat"], data=[[0] * 3])

ax/core/tests/test_experiment.py

Lines changed: 4 additions & 2 deletions

@@ -673,7 +673,7 @@ def test_fetch_and_store_data(self) -> None:

         # Verify we do get the stored data if there are an unimplemented metrics.
         # Remove attached data for nonexistent metric.
-        exp.data.full_df = exp.data.full_df.loc[lambda x: x["metric_name"] != "z"]
+        exp.data = Data(df=exp.data.full_df.loc[lambda x: x["metric_name"] != "z"])

         # Remove implemented metric that is `available_while_running`
         # (and therefore not pulled from cache).
@@ -685,7 +685,9 @@ def test_fetch_and_store_data(self) -> None:
         looked_up_df = looked_up_data.full_df
         self.assertFalse((looked_up_df["metric_name"] == "z").any())
         self.assertTrue(
-            batch.fetch_data().full_df.equals(
+            batch.fetch_data()
+            .full_df.sort_values(["arm_name", "metric_name"], ignore_index=True)
+            .equals(
                 looked_up_df.loc[lambda x: (x["trial_index"] == 0)].sort_values(
                     ["arm_name", "metric_name"], ignore_index=True
                 )

ax/plot/pareto_utils.py

Lines changed: 1 addition & 1 deletion

@@ -207,7 +207,7 @@ def get_observed_pareto_frontiers(
     ):
         # Make sure status quo is always included, for derelativization
         arm_names.append(experiment.status_quo.name)
-        data = Data(data.df[data.df["arm_name"].isin(arm_names)])
+        data = Data(df=data.df[data.df["arm_name"].isin(arm_names)])
     adapter = get_tensor_converter_adapter(experiment=experiment, data=data)
     pareto_observations = observed_pareto_frontier(adapter=adapter)
     # Convert to ParetoFrontierResults

ax/plot/scatter.py

Lines changed: 1 addition & 1 deletion

@@ -1731,7 +1731,7 @@ def tile_observations(
     if data is None:
         data = experiment.fetch_data()
     if arm_names is not None:
-        data = Data(data.df[data.df["arm_name"].isin(arm_names)])
+        data = Data(df=data.df[data.df["arm_name"].isin(arm_names)])
     m_ts = Generators.THOMPSON(
         data=data,
         search_space=experiment.search_space,

ax/plot/tests/test_fitted_scatter.py

Lines changed: 1 addition & 1 deletion

@@ -33,7 +33,7 @@ def test_fitted_scatter(self) -> None:
         model = Generators.BOTORCH_MODULAR(
             # Adapter kwargs
             experiment=exp,
-            data=Data.from_multiple_data([data, Data(df)]),
+            data=Data.from_multiple_data([data, Data(df=df)]),
         )
         # Assert that each type of plot can be constructed successfully
         scalarized_metric_config = [

ax/plot/tests/test_pareto_utils.py

Lines changed: 1 addition & 1 deletion

@@ -107,7 +107,7 @@ def test_get_observed_pareto_frontiers(self) -> None:
         # For the check below, compute which arms are better than SQ
         df = experiment.fetch_data().df
         df["sem"] = np.nan
-        data = Data(df)
+        data = Data(df=df)
         sq_val = df[(df["arm_name"] == "status_quo") & (df["metric_name"] == "m1")][
             "mean"
         ].values[0]

ax/storage/json_store/tests/test_json_store.py

Lines changed: 2 additions & 6 deletions

@@ -705,10 +705,6 @@ def test_decode_map_data_backward_compatible(self) -> None:
                 class_decoder_registry=CORE_CLASS_DECODER_REGISTRY,
             )
             self.assertEqual(len(map_data.full_df), 2)
-            # Even though the "epoch" and "timestamps" columns have not been
-            # renamed to "step", they are present
-            self.assertEqual(map_data.full_df["epoch"].tolist(), [0.0, 1.0])
-            self.assertEqual(map_data.full_df["timestamps"].tolist(), [3.0, 4.0])
             self.assertIsInstance(map_data, Data)

         with self.subTest("Single map key"):
@@ -729,8 +725,8 @@ def test_decode_map_data_backward_compatible(self) -> None:
                 decoder_registry=CORE_DECODER_REGISTRY,
                 class_decoder_registry=CORE_CLASS_DECODER_REGISTRY,
             )
-            self.assertIn("epoch", map_data.full_df.columns)
-            self.assertEqual(map_data.full_df["epoch"].tolist(), [0.0, 1.0])
+            self.assertEqual(len(map_data.full_df), 2)
+            self.assertIsInstance(map_data, Data)

         with self.subTest("No map key"):
             data_json = {
