Skip to content

Commit 6f43f73

Browse files
committed
add tests for slab get slab regions
1 parent e0374e9 commit 6f43f73

File tree

1 file changed

+54
-17
lines changed

1 file changed

+54
-17
lines changed

tests/core/test_surface.py

Lines changed: 54 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -294,24 +294,61 @@ def surface_area(s):
294294
assert surface_area(slab) == approx(surface_area(ouc))
295295
assert len(slab) >= len(ouc)
296296

297-
def test_get_slab_regions(self):
298-
# If a slab layer in the slab cell is not completely inside
299-
# the cell (noncontiguous), check that get_slab_regions will
300-
# be able to identify where the slab layers are located
297+
@staticmethod
298+
def _make_simple_slab(
299+
thickness: float,
300+
z_center: float = 0.5,
301+
n_sites: int = 5,
302+
) -> Structure:
303+
rng = np.random.default_rng(seed=0)
304+
305+
zmin = z_center - thickness / 2
306+
zmax = z_center + thickness / 2
307+
308+
zs = np.linspace(zmin, zmax, n_sites)
309+
xs = rng.random(n_sites) # uniform in [0,1)
310+
ys = rng.random(n_sites)
311+
312+
coords = [[x, y, z] for x, y, z in zip(xs, ys, zs, strict=True)]
313+
314+
struct = Structure(
315+
Lattice.cubic(10),
316+
["H"] * n_sites,
317+
coords,
318+
coords_are_cartesian=False,
319+
)
301320

302-
struct = self.get_structure("LiFePO4")
303-
slab_gen = SlabGenerator(struct, (0, 0, 1), 15, 15)
304-
slab = slab_gen.get_slabs()[0]
305-
slab.translate_sites([idx for idx, site in enumerate(slab)], [0, 0, -0.25])
306-
bottom_c, top_c = [], []
307-
for site in slab:
308-
if site.frac_coords[2] < 0.5:
309-
bottom_c.append(site.frac_coords[2])
310-
else:
311-
top_c.append(site.frac_coords[2])
312-
ranges = get_slab_regions(slab)
313-
assert tuple(ranges[0]) == (0, max(bottom_c))
314-
assert tuple(ranges[1]) == (min(top_c), 1)
321+
# Sanity checks
322+
assert xs.min() >= 0
323+
assert xs.max() < 1
324+
assert ys.min() >= 0
325+
assert ys.max() < 1
326+
327+
z_vals = struct.frac_coords[:, 2]
328+
assert np.allclose(z_vals, zs), "Structure unexpectedly altered z-values"
329+
330+
assert z_vals.max() - z_vals.min() == approx(thickness)
331+
332+
return struct
333+
334+
@pytest.mark.parametrize("z_center", [-0.5, 0.5, 1.5])
335+
def test_get_slab_regions_single_continuous(self, z_center):
336+
thickness = 0.2
337+
slab = self._make_simple_slab(thickness, z_center)
338+
regions = Slab.get_slab_regions(slab)
339+
340+
assert len(regions) == 1
341+
assert regions[0][0] == approx(z_center - thickness / 2)
342+
assert regions[0][1] == approx(z_center + thickness / 2)
343+
344+
def test_get_slab_regions_single_non_continuous(self):
345+
pass
346+
347+
def test_get_slab_regions_multiple(self):
348+
# Test two slab regions
349+
350+
# Test multiple slab regions
351+
pass
315352

316353
def test_as_dict(self):
317354
slabs = generate_all_slabs(

0 commit comments

Comments
 (0)