Skip to content

Commit 13ff42f

Browse files
mpolson64facebook-github-bot
authored andcommitted
Remove TestDataBase now that DataBase is gone (facebook#4772)
Summary: Moved these tests into TestData, since Data is the only data-related class in Ax. Reviewed By: saitcakmak Differential Revision: D90605845 Privacy Context Container: L1413903
1 parent 06ad954 commit 13ff42f

File tree

1 file changed

+16
-80
lines changed

1 file changed

+16
-80
lines changed

ax/core/tests/test_data.py

Lines changed: 16 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -119,29 +119,22 @@ def get_test_dataframe() -> pd.DataFrame:
119119
)
120120

121121

122-
class TestDataBase(TestCase):
123-
# Also run with has_step_column = True below
124-
has_step_column: bool = False
122+
class DataTest(TestCase):
123+
"""Tests for Data without a "step" column."""
125124

126125
def setUp(self) -> None:
127126
super().setUp()
128127
self.data_without_df = Data()
129-
df = get_test_dataframe()
130-
if not self.has_step_column:
131-
self.df = df
132-
self.data_with_df = Data(df=self.df)
133-
else:
134-
df_1 = df.copy().assign(**{MAP_KEY: 0})
135-
df_2 = df.copy().assign(**{MAP_KEY: 1})
136-
self.df = pd.concat((df_1, df_2))
137-
self.data_with_df = Data(df=self.df)
138-
128+
self.df = get_test_dataframe()
129+
self.data_with_df = Data(df=self.df)
139130
self.metric_name_to_signature = {"a": "a_signature", "b": "b_signature"}
140131

141132
def test_init(self) -> None:
133+
# Test equality
142134
self.assertEqual(self.data_without_df, self.data_without_df)
143135
self.assertEqual(self.data_with_df, self.data_with_df)
144136

137+
# Test accessing values
145138
df = self.data_with_df.df
146139
self.assertEqual(
147140
float(df[df["arm_name"] == "0_0"][df["metric_name"] == "a"]["mean"].item()),
@@ -152,7 +145,14 @@ def test_init(self) -> None:
152145
0.5,
153146
)
154147

155-
self.assertEqual(self.data_with_df.has_step_column, self.has_step_column)
148+
# Test has_step_column is False
149+
self.assertFalse(self.data_with_df.has_step_column)
150+
151+
# Test empty initialization
152+
empty = Data()
153+
self.assertTrue(empty.full_df.empty)
154+
self.assertEqual(set(empty.full_df.columns), empty.REQUIRED_COLUMNS)
155+
self.assertFalse(empty.has_step_column)
156156

157157
def test_clone(self) -> None:
158158
data = self.data_with_df
@@ -164,14 +164,9 @@ def test_clone(self) -> None:
164164
self.assertIsNot(data, data_clone)
165165
self.assertIsNot(data.df, data_clone.df)
166166
self.assertIsNone(data_clone._db_id)
167-
if self.has_step_column:
168-
self.assertIsNot(data.full_df, data_clone.full_df)
169-
self.assertTrue(data.full_df.equals(data_clone.full_df))
170167

171168
def test_BadData(self) -> None:
172169
data = {"bad_field": "0_0", "bad_field_2": {"x": 0, "y": "a"}}
173-
if self.has_step_column:
174-
data[MAP_KEY] = "0"
175170
df = pd.DataFrame([data])
176171
with self.assertRaisesRegex(
177172
ValueError, "Dataframe must contain required columns"
@@ -184,17 +179,13 @@ def test_EmptyData(self) -> None:
184179
self.assertTrue(df.empty)
185180
self.assertTrue(Data.from_multiple_data([]).df.empty)
186181

187-
if data.has_step_column:
188-
self.assertTrue(data.full_df.empty)
189-
expected_columns = Data.REQUIRED_COLUMNS.union({MAP_KEY})
190-
else:
191-
expected_columns = Data.REQUIRED_COLUMNS
182+
expected_columns = Data.REQUIRED_COLUMNS
192183
self.assertEqual(set(df.columns), expected_columns)
193184

194185
def test_from_multiple_with_generator(self) -> None:
195186
data = Data.from_multiple_data(self.data_with_df for _ in range(2))
196187
self.assertEqual(len(data.full_df), 2 * len(self.data_with_df.full_df))
197-
self.assertEqual(data.has_step_column, self.has_step_column)
188+
self.assertFalse(data.has_step_column)
198189

199190
def test_extra_columns(self) -> None:
200191
value = 3
@@ -235,26 +226,6 @@ def test_trial_indices(self) -> None:
235226
set(self.data_with_df.full_df["trial_index"].unique()),
236227
)
237228

238-
239-
class TestMapData(TestDataBase):
240-
has_step_column = True
241-
242-
243-
class DataTest(TestCase):
244-
"""Tests that are specific to Data without a "step" column."""
245-
246-
def setUp(self) -> None:
247-
super().setUp()
248-
self.df = get_test_dataframe()
249-
self.metric_name_to_signature = {"a": "a_signature", "b": "b_signature"}
250-
251-
def test_init(self) -> None:
252-
# Initialize empty
253-
empty = Data()
254-
self.assertTrue(empty.full_df.empty)
255-
self.assertEqual(set(empty.full_df.columns), empty.REQUIRED_COLUMNS)
256-
self.assertFalse(empty.has_step_column)
257-
258229
def test_repr(self) -> None:
259230
self.assertEqual(
260231
str(Data(df=self.df)),
@@ -263,13 +234,6 @@ def test_repr(self) -> None:
263234
with patch(f"{Data.__module__}.DF_REPR_MAX_LENGTH", 500):
264235
self.assertEqual(str(Data(df=self.df)), REPR_500)
265236

266-
def test_OtherClassInequality(self) -> None:
267-
class CustomData(Data):
268-
pass
269-
270-
data = CustomData(df=self.df)
271-
self.assertNotEqual(data, Data(self.df))
272-
273237
def test_from_multiple(self) -> None:
274238
with self.subTest("Combinining non-empty Data"):
275239
data = Data.from_multiple_data([Data(df=self.df), Data(df=self.df)])
@@ -279,34 +243,6 @@ def test_from_multiple(self) -> None:
279243
data = Data.from_multiple_data([Data(), Data()])
280244
self.assertEqual(data, Data())
281245

282-
with self.subTest("Different types"):
283-
284-
class CustomData(Data):
285-
pass
286-
287-
data = Data.from_multiple_data([CustomData(), CustomData()])
288-
self.assertEqual(data, Data())
289-
data = CustomData.from_multiple_data([Data(), CustomData()])
290-
self.assertEqual(data, CustomData())
291-
292-
def test_FromMultipleDataMismatchedTypes(self) -> None:
293-
# create two custom data types
294-
class CustomDataA(Data):
295-
pass
296-
297-
class CustomDataB(Data):
298-
pass
299-
300-
# Test using `Data.from_multiple_data` to combine non-Data types
301-
data = Data.from_multiple_data([CustomDataA(), CustomDataB()])
302-
self.assertEqual(data, Data())
303-
304-
# multiple non-empty types
305-
data_elt_A = CustomDataA(df=self.df)
306-
data_elt_B = CustomDataB(df=self.df)
307-
data = Data.from_multiple_data([data_elt_A, data_elt_B])
308-
self.assertEqual(len(data.full_df), 2 * len(self.df))
309-
310246
def test_filter(self) -> None:
311247
data = Data(df=self.df)
312248
# Test that filter throws when we provide metric names and metric signatures

0 commit comments

Comments
 (0)