@@ -119,29 +119,22 @@ def get_test_dataframe() -> pd.DataFrame:
119119 )
120120
121121
122- class TestDataBase (TestCase ):
123- # Also run with has_step_column = True below
124- has_step_column : bool = False
122+ class DataTest (TestCase ):
123+ """Tests for Data without a "step" column."""
125124
126125 def setUp (self ) -> None :
127126 super ().setUp ()
128127 self .data_without_df = Data ()
129- df = get_test_dataframe ()
130- if not self .has_step_column :
131- self .df = df
132- self .data_with_df = Data (df = self .df )
133- else :
134- df_1 = df .copy ().assign (** {MAP_KEY : 0 })
135- df_2 = df .copy ().assign (** {MAP_KEY : 1 })
136- self .df = pd .concat ((df_1 , df_2 ))
137- self .data_with_df = Data (df = self .df )
138-
128+ self .df = get_test_dataframe ()
129+ self .data_with_df = Data (df = self .df )
139130 self .metric_name_to_signature = {"a" : "a_signature" , "b" : "b_signature" }
140131
141132 def test_init (self ) -> None :
133+ # Test equality
142134 self .assertEqual (self .data_without_df , self .data_without_df )
143135 self .assertEqual (self .data_with_df , self .data_with_df )
144136
137+ # Test accessing values
145138 df = self .data_with_df .df
146139 self .assertEqual (
147140 float (df [df ["arm_name" ] == "0_0" ][df ["metric_name" ] == "a" ]["mean" ].item ()),
@@ -152,7 +145,14 @@ def test_init(self) -> None:
152145 0.5 ,
153146 )
154147
155- self .assertEqual (self .data_with_df .has_step_column , self .has_step_column )
148+ # Test has_step_column is False
149+ self .assertFalse (self .data_with_df .has_step_column )
150+
151+ # Test empty initialization
152+ empty = Data ()
153+ self .assertTrue (empty .full_df .empty )
154+ self .assertEqual (set (empty .full_df .columns ), empty .REQUIRED_COLUMNS )
155+ self .assertFalse (empty .has_step_column )
156156
157157 def test_clone (self ) -> None :
158158 data = self .data_with_df
@@ -164,14 +164,9 @@ def test_clone(self) -> None:
164164 self .assertIsNot (data , data_clone )
165165 self .assertIsNot (data .df , data_clone .df )
166166 self .assertIsNone (data_clone ._db_id )
167- if self .has_step_column :
168- self .assertIsNot (data .full_df , data_clone .full_df )
169- self .assertTrue (data .full_df .equals (data_clone .full_df ))
170167
171168 def test_BadData (self ) -> None :
172169 data = {"bad_field" : "0_0" , "bad_field_2" : {"x" : 0 , "y" : "a" }}
173- if self .has_step_column :
174- data [MAP_KEY ] = "0"
175170 df = pd .DataFrame ([data ])
176171 with self .assertRaisesRegex (
177172 ValueError , "Dataframe must contain required columns"
@@ -184,17 +179,13 @@ def test_EmptyData(self) -> None:
184179 self .assertTrue (df .empty )
185180 self .assertTrue (Data .from_multiple_data ([]).df .empty )
186181
187- if data .has_step_column :
188- self .assertTrue (data .full_df .empty )
189- expected_columns = Data .REQUIRED_COLUMNS .union ({MAP_KEY })
190- else :
191- expected_columns = Data .REQUIRED_COLUMNS
182+ expected_columns = Data .REQUIRED_COLUMNS
192183 self .assertEqual (set (df .columns ), expected_columns )
193184
194185 def test_from_multiple_with_generator (self ) -> None :
195186 data = Data .from_multiple_data (self .data_with_df for _ in range (2 ))
196187 self .assertEqual (len (data .full_df ), 2 * len (self .data_with_df .full_df ))
197- self .assertEqual (data . has_step_column , self .has_step_column )
188+ self .assertFalse (data .has_step_column )
198189
199190 def test_extra_columns (self ) -> None :
200191 value = 3
@@ -235,26 +226,6 @@ def test_trial_indices(self) -> None:
235226 set (self .data_with_df .full_df ["trial_index" ].unique ()),
236227 )
237228
238-
239- class TestMapData (TestDataBase ):
240- has_step_column = True
241-
242-
243- class DataTest (TestCase ):
244- """Tests that are specific to Data without a "step" column."""
245-
246- def setUp (self ) -> None :
247- super ().setUp ()
248- self .df = get_test_dataframe ()
249- self .metric_name_to_signature = {"a" : "a_signature" , "b" : "b_signature" }
250-
251- def test_init (self ) -> None :
252- # Initialize empty
253- empty = Data ()
254- self .assertTrue (empty .full_df .empty )
255- self .assertEqual (set (empty .full_df .columns ), empty .REQUIRED_COLUMNS )
256- self .assertFalse (empty .has_step_column )
257-
258229 def test_repr (self ) -> None :
259230 self .assertEqual (
260231 str (Data (df = self .df )),
@@ -263,13 +234,6 @@ def test_repr(self) -> None:
263234 with patch (f"{ Data .__module__ } .DF_REPR_MAX_LENGTH" , 500 ):
264235 self .assertEqual (str (Data (df = self .df )), REPR_500 )
265236
266- def test_OtherClassInequality (self ) -> None :
267- class CustomData (Data ):
268- pass
269-
270- data = CustomData (df = self .df )
271- self .assertNotEqual (data , Data (self .df ))
272-
273237 def test_from_multiple (self ) -> None :
274238 with self .subTest ("Combinining non-empty Data" ):
275239 data = Data .from_multiple_data ([Data (df = self .df ), Data (df = self .df )])
@@ -279,34 +243,6 @@ def test_from_multiple(self) -> None:
279243 data = Data .from_multiple_data ([Data (), Data ()])
280244 self .assertEqual (data , Data ())
281245
282- with self .subTest ("Different types" ):
283-
284- class CustomData (Data ):
285- pass
286-
287- data = Data .from_multiple_data ([CustomData (), CustomData ()])
288- self .assertEqual (data , Data ())
289- data = CustomData .from_multiple_data ([Data (), CustomData ()])
290- self .assertEqual (data , CustomData ())
291-
292- def test_FromMultipleDataMismatchedTypes (self ) -> None :
293- # create two custom data types
294- class CustomDataA (Data ):
295- pass
296-
297- class CustomDataB (Data ):
298- pass
299-
300- # Test using `Data.from_multiple_data` to combine non-Data types
301- data = Data .from_multiple_data ([CustomDataA (), CustomDataB ()])
302- self .assertEqual (data , Data ())
303-
304- # multiple non-empty types
305- data_elt_A = CustomDataA (df = self .df )
306- data_elt_B = CustomDataB (df = self .df )
307- data = Data .from_multiple_data ([data_elt_A , data_elt_B ])
308- self .assertEqual (len (data .full_df ), 2 * len (self .df ))
309-
310246 def test_filter (self ) -> None :
311247 data = Data (df = self .df )
312248 # Test that filter throws when we provide metric names and metric signatures
0 commit comments