From 5666eaff1293b6c9fa31c6f115e78e7df740a732 Mon Sep 17 00:00:00 2001
From: "Andrew S. Rosen" <asrosen93@gmail.com>
Date: Sat, 16 Mar 2024 23:45:35 -0400
Subject: [PATCH] Add an `ase_relax_job` for VASP (#1888)

Closes #1887.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 src/quacc/recipes/onetep/core.py              |  3 +-
 src/quacc/recipes/vasp/_base.py               | 66 ++++++++++++++++++-
 src/quacc/recipes/vasp/core.py                | 56 +++++++++++++++-
 src/quacc/schemas/_aliases/vasp.py            |  6 +-
 .../vasp_recipes/mocked/test_vasp_recipes.py  | 17 ++++-
 5 files changed, 141 insertions(+), 7 deletions(-)

diff --git a/src/quacc/recipes/onetep/core.py b/src/quacc/recipes/onetep/core.py
index bc23d350fc..f8b9b6e119 100644
--- a/src/quacc/recipes/onetep/core.py
+++ b/src/quacc/recipes/onetep/core.py
@@ -69,6 +69,7 @@ def static_job(
 @job
 def ase_relax_job(
     atoms: Atoms,
+    relax_cell: bool = False,
     copy_files: SourceDirectory | dict[SourceDirectory, Filenames] | None = None,
     opt_params: dict[str, Any] | None = None,
     **calc_kwargs,
@@ -102,7 +103,7 @@ def ase_relax_job(
         {"keywords": {"write_forces": True, "forces_output_detail": "verbose"}},
     )
 
-    opt_defaults = {"optimizer": LBFGS}
+    opt_defaults = {"optimizer": LBFGS, "relax_cell": relax_cell}
 
     return base_opt_fn(
         atoms,
diff --git a/src/quacc/recipes/vasp/_base.py b/src/quacc/recipes/vasp/_base.py
index bd6d077a65..73c8097442 100644
--- a/src/quacc/recipes/vasp/_base.py
+++ b/src/quacc/recipes/vasp/_base.py
@@ -4,8 +4,10 @@
 
 from typing import TYPE_CHECKING
 
+from quacc.atoms.core import get_final_atoms_from_dyn
 from quacc.calculators.vasp import Vasp
-from quacc.runners.ase import run_calc
+from quacc.runners.ase import run_calc, run_opt
+from quacc.schemas.ase import summarize_opt_run
 from quacc.schemas.vasp import vasp_summarize_run
 from quacc.utils.dicts import recursive_dict_merge
 
@@ -52,7 +54,7 @@ def base_fn(
     Returns
     -------
     VaspSchema
-        Dictionary of results from [quacc.schemas.vasp.vasp_summarize_run][]
+        Dictionary of results
     """
     calc_flags = recursive_dict_merge(calc_defaults, calc_swaps)
 
@@ -64,3 +66,63 @@ def base_fn(
         report_mp_corrections=report_mp_corrections,
         additional_fields=additional_fields,
     )
+
+
+def base_opt_fn(
+    atoms: Atoms,
+    preset: str | None = None,
+    calc_defaults: dict[str, Any] | None = None,
+    calc_swaps: dict[str, Any] | None = None,
+    opt_defaults: dict[str, Any] | None = None,
+    opt_params: dict[str, Any] | None = None,
+    report_mp_corrections: bool = False,
+    additional_fields: dict[str, Any] | None = None,
+    copy_files: SourceDirectory | dict[SourceDirectory, Filenames] | None = None,
+) -> VaspSchema:
+    """
+    Base job function for VASP recipes.
+
+    Parameters
+    ----------
+    atoms
+        Atoms object
+    preset
+        Preset to use from `quacc.calculators.vasp.presets`.
+    calc_defaults
+        Default parameters for the recipe.
+    calc_swaps
+        Dictionary of custom kwargs for the Vasp calculator. Set a value to
+        `None` to remove a pre-existing key entirely. For a list of available
+        keys, refer to [quacc.calculators.vasp.vasp.Vasp][].
+    opt_defaults
+        Default arguments for the ASE optimizer.
+    opt_params
+        Dictionary of custom kwargs for [quacc.runners.ase.run_opt][]
+    report_mp_corrections
+        Whether to report the Materials Project corrections in the results.
+    additional_fields
+        Additional fields to supply to the summarizer.
+    copy_files
+        Files to copy (and decompress) from source to the runtime directory.
+
+    Returns
+    -------
+    VaspASESchema
+        Dictionary of results
+    """
+    calc_flags = recursive_dict_merge(calc_defaults, calc_swaps)
+    opt_flags = recursive_dict_merge(opt_defaults, opt_params)
+
+    atoms.calc = Vasp(atoms, preset=preset, **calc_flags)
+    dyn = run_opt(atoms, copy_files=copy_files, **opt_flags)
+
+    opt_run_summary = summarize_opt_run(dyn, additional_fields=additional_fields)
+
+    final_atoms = get_final_atoms_from_dyn(dyn)
+
+    vasp_summary = vasp_summarize_run(
+        final_atoms,
+        report_mp_corrections=report_mp_corrections,
+        additional_fields=additional_fields,
+    )
+    return recursive_dict_merge(vasp_summary, opt_run_summary)
diff --git a/src/quacc/recipes/vasp/core.py b/src/quacc/recipes/vasp/core.py
index 945dcdff5f..08a09b4a9b 100644
--- a/src/quacc/recipes/vasp/core.py
+++ b/src/quacc/recipes/vasp/core.py
@@ -10,14 +10,14 @@
 from pymatgen.io.vasp import Vasprun
 
 from quacc import flow, job
-from quacc.recipes.vasp._base import base_fn
+from quacc.recipes.vasp._base import base_fn, base_opt_fn
 
 if TYPE_CHECKING:
     from typing import Any
 
     from ase.atoms import Atoms
 
-    from quacc.schemas._aliases.vasp import DoubleRelaxSchema, VaspSchema
+    from quacc.schemas._aliases.vasp import DoubleRelaxSchema, VaspASESchema, VaspSchema
     from quacc.utils.files import Filenames, SourceDirectory
 
 
@@ -180,6 +180,58 @@ def double_relax_flow(
     return {"relax1": summary1, "relax2": summary2}
 
 
+@job
+def ase_relax_job(
+    atoms: Atoms,
+    preset: str | None = "BulkSet",
+    relax_cell: bool = True,
+    opt_params: dict[str, Any] | None = None,
+    copy_files: SourceDirectory | dict[SourceDirectory, Filenames] | None = None,
+    **calc_kwargs,
+) -> VaspASESchema:
+    """
+    Relax a structure.
+
+    Parameters
+    ----------
+    atoms
+        Atoms object
+    preset
+        Preset to use from `quacc.calculators.vasp.presets`.
+    relax_cell
+        True if a volume relaxation should be performed. False if only the positions
+        should be updated.
+    copy_files
+        Files to copy (and decompress) from source to the runtime directory.
+    **calc_kwargs
+        Custom kwargs for the Vasp calculator. Set a value to
+        `None` to remove a pre-existing key entirely. For a list of available
+        keys, refer to the [quacc.calculators.vasp.vasp.Vasp][] calculator.
+
+    Returns
+    -------
+    VaspASESchema
+        Dictionary of results. See the type-hint for the data structure.
+    """
+
+    calc_defaults = {
+        "lcharg": False,
+        "lwave": False,
+        "nsw": 0,
+    }
+    opt_defaults = {"relax_cell": relax_cell}
+    return base_opt_fn(
+        atoms,
+        preset=preset,
+        calc_defaults=calc_defaults,
+        calc_swaps=calc_kwargs,
+        opt_defaults=opt_defaults,
+        opt_params=opt_params,
+        additional_fields={"name": "VASP ASE Relax"},
+        copy_files=copy_files,
+    )
+
+
 @job
 def non_scf_job(
     atoms: Atoms,
diff --git a/src/quacc/schemas/_aliases/vasp.py b/src/quacc/schemas/_aliases/vasp.py
index 9e9e230481..6d23ab2b0a 100644
--- a/src/quacc/schemas/_aliases/vasp.py
+++ b/src/quacc/schemas/_aliases/vasp.py
@@ -4,7 +4,7 @@
 
 from typing import TypedDict
 
-from quacc.schemas._aliases.ase import RunSchema
+from quacc.schemas._aliases.ase import OptSchema, RunSchema
 from quacc.schemas._aliases.emmet import TaskDoc
 
 
@@ -72,3 +72,7 @@ class QMOFRelaxSchema(VaspSchema):
     position_relax_lowacc: VaspSchema
     volume_relax_lowacc: VaspSchema | None
     double_relax: VaspSchema
+
+
+class VaspASESchema(VaspSchema, OptSchema):
+    """Type hint associated with VASP relaxations run via ASE"""
diff --git a/tests/core/recipes/vasp_recipes/mocked/test_vasp_recipes.py b/tests/core/recipes/vasp_recipes/mocked/test_vasp_recipes.py
index 1ea18a0959..f66d1d66ea 100644
--- a/tests/core/recipes/vasp_recipes/mocked/test_vasp_recipes.py
+++ b/tests/core/recipes/vasp_recipes/mocked/test_vasp_recipes.py
@@ -8,7 +8,7 @@
 from quacc import SETTINGS
 from quacc.recipes.vasp.core import (
     double_relax_flow,
-    non_scf_job,
+    non_scf_job,ase_relax_job,
     relax_job,
     static_job,
 )
@@ -161,6 +161,21 @@ def test_doublerelax_flow(tmp_path, monkeypatch):
     assert double_relax_flow(atoms, relax1_kwargs={"kpts": [1, 1, 1]})
 
 
+
+def test_ase_relax_job(tmp_path, monkeypatch):
+    monkeypatch.chdir(tmp_path)
+
+    atoms = bulk("Al")
+
+    output = ase_relax_job(atoms)
+    assert output["nsites"] == len(atoms)
+    assert output["parameters"]["nsw"] == 0
+    assert output["parameters"]["lwave"] is False
+    assert output["parameters"]["lcharg"] is False
+    assert output["parameters"]["encut"] == 520
+    assert output["fmax"] == 0.01
+    assert len(output["trajectory_results"]) > 1
+
 def test_non_scf_job1(tmp_path, monkeypatch):
     monkeypatch.chdir(tmp_path)
     copy(MOCKED_DIR / "vasprun.xml.gz", tmp_path / "vasprun.xml.gz")