Skip to content

Commit 935348e

Browse files
committed
add tests for slab get slab regions
1 parent e0374e9 commit 935348e

File tree

2 files changed

+63
-28
lines changed

2 files changed

+63
-28
lines changed

src/pymatgen/core/surface.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -744,15 +744,10 @@ def center_slab(slab: Structure) -> Structure:
744744
This makes it easier to find surface sites and apply
745745
operations like doping.
746746
747-
There are two possible cases:
748-
1. When the slab region is completely positioned between
749-
two vacuum layers in the cell but is not centered, we simply
750-
shift the slab to the center along z-axis.
751-
2. If the slab completely resides outside the cell either
752-
from the bottom or the top, we iterate through all sites that
753-
spill over and shift all sites such that it is now
754-
on the other side. An edge case being, either the top
755-
of the slab is at z = 0 or the bottom is at z = 1.
747+
TODOs:
748+
- This assume there're only one or two slab regions, but I
749+
guess it's possible that there might be more (maybe a warning in this case)?
750+
- This doesn't work if site is outside the cell.
756751
757752
Args:
758753
slab (Structure): The slab to center.

tests/core/test_surface.py

Lines changed: 59 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -294,25 +294,6 @@ def surface_area(s):
294294
assert surface_area(slab) == approx(surface_area(ouc))
295295
assert len(slab) >= len(ouc)
296296

297-
def test_get_slab_regions(self):
298-
# If a slab layer in the slab cell is not completely inside
299-
# the cell (noncontiguous), check that get_slab_regions will
300-
# be able to identify where the slab layers are located
301-
302-
struct = self.get_structure("LiFePO4")
303-
slab_gen = SlabGenerator(struct, (0, 0, 1), 15, 15)
304-
slab = slab_gen.get_slabs()[0]
305-
slab.translate_sites([idx for idx, site in enumerate(slab)], [0, 0, -0.25])
306-
bottom_c, top_c = [], []
307-
for site in slab:
308-
if site.frac_coords[2] < 0.5:
309-
bottom_c.append(site.frac_coords[2])
310-
else:
311-
top_c.append(site.frac_coords[2])
312-
ranges = get_slab_regions(slab)
313-
assert tuple(ranges[0]) == (0, max(bottom_c))
314-
assert tuple(ranges[1]) == (min(top_c), 1)
315-
316297
def test_as_dict(self):
317298
slabs = generate_all_slabs(
318299
self.ti,
@@ -346,6 +327,65 @@ def test_as_dict(self):
346327
assert slab == Slab.from_dict(d)
347328

348329

330+
class TestSlabHelpers:
331+
# TODO: this should be merged into `TestSlab`
332+
@staticmethod
333+
def _make_simple_slab(
334+
thickness: float,
335+
z_center: float = 0.5,
336+
n_sites: int = 5,
337+
) -> Structure:
338+
rng = np.random.default_rng(seed=0)
339+
340+
zmin = z_center - thickness / 2
341+
zmax = z_center + thickness / 2
342+
343+
zs = np.linspace(zmin, zmax, n_sites)
344+
xs = rng.random(n_sites) # uniform in [0,1)
345+
ys = rng.random(n_sites)
346+
347+
coords = [[x, y, z] for x, y, z in zip(xs, ys, zs, strict=True)]
348+
349+
struct = Structure(
350+
Lattice.cubic(10),
351+
["H"] * n_sites,
352+
coords,
353+
coords_are_cartesian=False,
354+
)
355+
356+
# Sanity checks
357+
assert xs.min() >= 0
358+
assert xs.max() < 1
359+
assert ys.min() >= 0
360+
assert ys.max() < 1
361+
362+
z_vals = struct.frac_coords[:, 2]
363+
assert np.allclose(z_vals, zs), "Structure unexpectedly altered z-values"
364+
365+
assert z_vals.max() - z_vals.min() == approx(thickness)
366+
367+
return struct
368+
369+
@pytest.mark.parametrize("z_center", [-0.5, 0.5, 1.5])
370+
def test_get_slab_regions_single_continuous(self, z_center):
371+
thickness = 0.2
372+
slab = self._make_simple_slab(thickness, z_center)
373+
regions = get_slab_regions(slab)
374+
375+
assert len(regions) == 1
376+
assert regions[0][0] == approx(z_center - thickness / 2)
377+
assert regions[0][1] == approx(z_center + thickness / 2)
378+
379+
def test_get_slab_regions_single_non_continuous(self):
380+
pass
381+
382+
def test_get_slab_regions_multiple(self):
383+
# Test two slab regions
384+
385+
# Test multiple slab regions
386+
pass
387+
388+
349389
class TestSlabGenerator(MatSciTest):
350390
def setup_method(self):
351391
lattice = Lattice.cubic(3.010)

0 commit comments

Comments
 (0)