Skip to content

Commit f5c5922

Browse files
committed
add tests for slab get slab regions
1 parent f564944 commit f5c5922

File tree

1 file changed

+54
-17
lines changed

1 file changed

+54
-17
lines changed

tests/core/test_surface.py

Lines changed: 54 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -293,24 +293,61 @@ def surface_area(s):
293293
assert surface_area(slab) == approx(surface_area(ouc))
294294
assert len(slab) >= len(ouc)
295295

296-
def test_get_slab_regions(self):
297-
# If a slab layer in the slab cell is not completely inside
298-
# the cell (noncontiguous), check that get_slab_regions will
299-
# be able to identify where the slab layers are located
296+
@staticmethod
297+
def _make_simple_slab(
298+
thickness: float,
299+
z_center: float = 0.5,
300+
n_sites: int = 5,
301+
) -> Structure:
302+
rng = np.random.default_rng(seed=0)
303+
304+
zmin = z_center - thickness / 2
305+
zmax = z_center + thickness / 2
306+
307+
zs = np.linspace(zmin, zmax, n_sites)
308+
xs = rng.random(n_sites) # uniform in [0,1)
309+
ys = rng.random(n_sites)
310+
311+
coords = [[x, y, z] for x, y, z in zip(xs, ys, zs, strict=True)]
312+
313+
struct = Structure(
314+
Lattice.cubic(10),
315+
["H"] * n_sites,
316+
coords,
317+
coords_are_cartesian=False,
318+
)
300319

301-
struct = self.get_structure("LiFePO4")
302-
slab_gen = SlabGenerator(struct, (0, 0, 1), 15, 15)
303-
slab = slab_gen.get_slabs()[0]
304-
slab.translate_sites([idx for idx, site in enumerate(slab)], [0, 0, -0.25])
305-
bottom_c, top_c = [], []
306-
for site in slab:
307-
if site.frac_coords[2] < 0.5:
308-
bottom_c.append(site.frac_coords[2])
309-
else:
310-
top_c.append(site.frac_coords[2])
311-
ranges = Slab.get_slab_regions(slab)
312-
assert tuple(ranges[0]) == (0, max(bottom_c))
313-
assert tuple(ranges[1]) == (min(top_c), 1)
320+
# Sanity checks
321+
assert xs.min() >= 0
322+
assert xs.max() < 1
323+
assert ys.min() >= 0
324+
assert ys.max() < 1
325+
326+
z_vals = struct.frac_coords[:, 2]
327+
assert np.allclose(z_vals, zs), "Structure unexpectedly altered z-values"
328+
329+
assert z_vals.max() - z_vals.min() == approx(thickness)
330+
331+
return struct
332+
333+
@pytest.mark.parametrize("z_center", [-0.5, 0.5, 1.5])
334+
def test_get_slab_regions_single_continuous(self, z_center):
335+
thickness = 0.2
336+
slab = self._make_simple_slab(thickness, z_center)
337+
regions = Slab.get_slab_regions(slab)
338+
339+
assert len(regions) == 1
340+
assert regions[0][0] == approx(z_center - thickness / 2)
341+
assert regions[0][1] == approx(z_center + thickness / 2)
342+
343+
def test_get_slab_regions_single_non_continuous(self):
344+
pass
345+
346+
def test_get_slab_regions_multiple(self):
347+
# Test two slab regions
348+
349+
# Test multiple slab regions
350+
pass
314351

315352
def test_as_dict(self):
316353
slabs = generate_all_slabs(

0 commit comments

Comments
 (0)