Skip to content

Commit

Permalink
Add more customize_funcs tests (#1863)
Browse files Browse the repository at this point in the history
Addresses #1852?
  • Loading branch information
Andrew-S-Rosen authored Mar 9, 2024
1 parent 2e9059a commit fae6393
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 18 deletions.
8 changes: 4 additions & 4 deletions src/quacc/wflow_tools/customizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,14 +187,14 @@ def customize_funcs(

for i, func in enumerate(funcs):
func_ = deepcopy(func)
if params := parameters.get("all"):
func_ = update_parameters(func_, params)
if params := parameters.get(names[i]):
func_ = update_parameters(func_, params)
if "all" in decorators:
func_ = redecorate(func_, decorators["all"])
if names[i] in decorators:
func_ = redecorate(func_, decorators[names[i]])
if params := parameters.get("all"):
func_ = update_parameters(func_, params)
if params := parameters.get(names[i]):
func_ = update_parameters(func_, params)
updated_funcs.append(func_)

return updated_funcs[0] if len(updated_funcs) == 1 else tuple(updated_funcs)
10 changes: 10 additions & 0 deletions tests/core/wflow/test_customizers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

from quacc import job
from quacc.wflow_tools.customizers import customize_funcs


Expand All @@ -22,6 +23,15 @@ def mult(a, b=1, c=2, d=2):
assert add_(1) == 5
assert mult_(1) == 8

add_, mult_ = customize_funcs(
["add", "mult"],
[add, mult],
parameters={"add": {"b": 2}, "mult": {"b": 2}},
decorators={"add": job(), "mult": job()},
)
assert add_(1) == 5
assert mult_(1) == 8

add_, mult_ = customize_funcs(
["add", "mult"], [add, mult], parameters={"all": {"b": 2}}
)
Expand Down
10 changes: 8 additions & 2 deletions tests/covalent/test_emt_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,22 @@

from ase.build import bulk

from quacc import job

# from quacc import flow
# from quacc.recipes.emt.core import relax_job
from quacc.recipes.emt.slabs import bulk_to_slabs_flow # skipcq: PYL-C0412


def test_functools(tmp_path, monkeypatch):
@pytest.mark.parametrize("job_decorators", [None, {"relax_job": job()}])
def test_functools(tmp_path, monkeypatch, job_decorators):
monkeypatch.chdir(tmp_path)
atoms = bulk("Cu")
dispatch_id = ct.dispatch(bulk_to_slabs_flow)(
atoms, job_params={"relax_job": {"opt_params": {"fmax": 0.1}}}, run_static=False
atoms,
run_static=False,
job_params={"relax_job": {"opt_params": {"fmax": 0.1}}},
job_decorators=job_decorators,
)
output = ct.get_result(dispatch_id, wait=True)
assert output.status == "COMPLETED"
Expand Down
10 changes: 7 additions & 3 deletions tests/dask/test_emt_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,22 @@
from ase.build import bulk
from dask.distributed import get_client

from quacc import flow
from quacc import flow, job
from quacc.recipes.emt.core import relax_job # skipcq: PYL-C0412
from quacc.recipes.emt.slabs import bulk_to_slabs_flow # skipcq: PYL-C0412

client = get_client()


def test_functools(tmp_path, monkeypatch):
@pytest.mark.parametrize("job_decorators", [None, {"relax_job": job()}])
def test_functools(tmp_path, monkeypatch, job_decorators):
monkeypatch.chdir(tmp_path)
atoms = bulk("Cu")
delayed = bulk_to_slabs_flow(
atoms, job_params={"relax_job": {"opt_params": {"fmax": 0.1}}}, run_static=False
atoms,
run_static=False,
job_params={"relax_job": {"opt_params": {"fmax": 0.1}}},
job_decorators=job_decorators,
)
result = client.compute(delayed).result()
assert len(result) == 4
Expand Down
10 changes: 7 additions & 3 deletions tests/parsl/test_emt_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ase.build import bulk

from quacc import SETTINGS
from quacc import SETTINGS, job
from quacc.recipes.emt.core import relax_job # skipcq: PYL-C0412
from quacc.recipes.emt.slabs import bulk_to_slabs_flow # skipcq: PYL-C0412

Expand All @@ -15,14 +15,18 @@


@pytest.mark.parametrize("chdir", [True, False])
def test_functools(tmp_path, monkeypatch, chdir):
@pytest.mark.parametrize("job_decorators", [None, {"relax_job": job()}])
def test_functools(tmp_path, monkeypatch, chdir, job_decorators):
monkeypatch.chdir(tmp_path)

SETTINGS.CHDIR = chdir

atoms = bulk("Cu")
result = bulk_to_slabs_flow(
atoms, job_params={"relax_job": {"opt_params": {"fmax": 0.1}}}, run_static=False
atoms,
run_static=False,
job_params={"relax_job": {"opt_params": {"fmax": 0.1}}},
job_decorators=job_decorators,
).result()
assert len(result) == 4
assert "atoms" in result[-1]
Expand Down
10 changes: 7 additions & 3 deletions tests/prefect/test_emt_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,20 @@

from ase.build import bulk

from quacc import flow
from quacc import flow, job
from quacc.recipes.emt.core import relax_job
from quacc.recipes.emt.slabs import bulk_to_slabs_flow # skipcq: PYL-C0412


def test_functools(tmp_path, monkeypatch):
@pytest.mark.parametrize("job_decorators", [None, {"relax_job": job()}])
def test_functools(tmp_path, monkeypatch, job_decorators):
monkeypatch.chdir(tmp_path)
atoms = bulk("Cu")
output = bulk_to_slabs_flow(
atoms, job_params={"relax_job": {"opt_params": {"fmax": 0.1}}}, run_static=False
atoms,
run_static=False,
job_params={"relax_job": {"opt_params": {"fmax": 0.1}}},
job_decorators=job_decorators,
)
result = [future.result() for future in output]
assert len(result) == 4
Expand Down
8 changes: 5 additions & 3 deletions tests/redun/test_emt_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,21 @@ def scheduler():
return redun.Scheduler()


from quacc import flow
from quacc import flow, job
from quacc.recipes.emt.core import relax_job
from quacc.recipes.emt.slabs import bulk_to_slabs_flow # skipcq: PYL-C0412


def test_functools(tmp_path, monkeypatch, scheduler):
@pytest.mark.parametrize("job_decorators", [None, {"relax_job": job()}])
def test_functools(tmp_path, monkeypatch, scheduler, job_decorators):
monkeypatch.chdir(tmp_path)
atoms = bulk("Cu")
result = scheduler.run(
bulk_to_slabs_flow(
atoms,
job_params={"relax_job": {"opt_params": {"fmax": 0.1}}},
run_static=False,
job_params={"relax_job": {"opt_params": {"fmax": 0.1}}},
job_decorators=job_decorators,
)
)
assert len(result) == 4
Expand Down

0 comments on commit fae6393

Please sign in to comment.