@@ -294,24 +294,61 @@ def surface_area(s):
294294 assert surface_area (slab ) == approx (surface_area (ouc ))
295295 assert len (slab ) >= len (ouc )
296296
297- def test_get_slab_regions (self ):
298- # If a slab layer in the slab cell is not completely inside
299- # the cell (noncontiguous), check that get_slab_regions will
300- # be able to identify where the slab layers are located
297+ @staticmethod
298+ def _make_simple_slab (
299+ thickness : float ,
300+ z_center : float = 0.5 ,
301+ n_sites : int = 5 ,
302+ ) -> Structure :
303+ rng = np .random .default_rng (seed = 0 )
304+
305+ zmin = z_center - thickness / 2
306+ zmax = z_center + thickness / 2
307+
308+ zs = np .linspace (zmin , zmax , n_sites )
309+ xs = rng .random (n_sites ) # uniform in [0,1)
310+ ys = rng .random (n_sites )
311+
312+ coords = [[x , y , z ] for x , y , z in zip (xs , ys , zs , strict = True )]
313+
314+ struct = Structure (
315+ Lattice .cubic (10 ),
316+ ["H" ] * n_sites ,
317+ coords ,
318+ coords_are_cartesian = False ,
319+ )
301320
302- struct = self .get_structure ("LiFePO4" )
303- slab_gen = SlabGenerator (struct , (0 , 0 , 1 ), 15 , 15 )
304- slab = slab_gen .get_slabs ()[0 ]
305- slab .translate_sites ([idx for idx , site in enumerate (slab )], [0 , 0 , - 0.25 ])
306- bottom_c , top_c = [], []
307- for site in slab :
308- if site .frac_coords [2 ] < 0.5 :
309- bottom_c .append (site .frac_coords [2 ])
310- else :
311- top_c .append (site .frac_coords [2 ])
312- ranges = get_slab_regions (slab )
313- assert tuple (ranges [0 ]) == (0 , max (bottom_c ))
314- assert tuple (ranges [1 ]) == (min (top_c ), 1 )
321+ # Sanity checks
322+ assert xs .min () >= 0
323+ assert xs .max () < 1
324+ assert ys .min () >= 0
325+ assert ys .max () < 1
326+
327+ z_vals = struct .frac_coords [:, 2 ]
328+ assert np .allclose (z_vals , zs ), "Structure unexpectedly altered z-values"
329+
330+ assert z_vals .max () - z_vals .min () == approx (thickness )
331+
332+ return struct
333+
334+ @pytest .mark .parametrize ("z_center" , [- 0.5 , 0.5 , 1.5 ])
335+ def test_get_slab_regions_single_continuous (self , z_center ):
336+ thickness = 0.2
337+ slab = self ._make_simple_slab (thickness , z_center )
338+ regions = Slab .get_slab_regions (slab )
339+
340+ assert len (regions ) == 1
341+ assert regions [0 ][0 ] == approx (z_center - thickness / 2 )
342+ assert regions [0 ][1 ] == approx (z_center + thickness / 2 )
343+
344+ def test_get_slab_regions_single_non_continuous (self ):
345+ pass
346+
347+ def test_get_slab_regions_multiple (self ):
348+ # Test two slab regions
349+
350+ # Test multiple slab regions
351+ pass
315352
316353 def test_as_dict (self ):
317354 slabs = generate_all_slabs (
0 commit comments