Skip to content

Commit 6a3e176

Browse files
RastislavTuranyiTuranyi
andauthored
Add progress bar to phonon calculations (#319)
* Add rich progress bar to Phonons.calc_force_constants * Add rich as dependency --------- Co-authored-by: Turanyi <[email protected]>
1 parent ee30bf6 commit 6a3e176

File tree

4 files changed

+64
-2
lines changed

4 files changed

+64
-2
lines changed

janus_core/calculations/phonons.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
PathLike,
2121
PhononCalcs,
2222
)
23-
from janus_core.helpers.utils import none_to_dict, write_table
23+
from janus_core.helpers.utils import none_to_dict, track_progress, write_table
2424

2525

2626
class Phonons(BaseCalculation):
@@ -90,6 +90,8 @@ class Phonons(BaseCalculation):
9090
file_prefix : Optional[PathLike]
9191
Prefix for output filenames. Default is inferred from chemical formula of the
9292
structure.
93+
enable_progress_bar : bool
94+
Whether to show a progress bar during phonon calculations. Default is False.
9395
9496
Attributes
9597
----------
@@ -152,6 +154,7 @@ def __init__(
152154
write_results: bool = True,
153155
write_full: bool = True,
154156
file_prefix: Optional[PathLike] = None,
157+
enable_progress_bar: bool = False,
155158
) -> None:
156159
"""
157160
Initialise Phonons class.
@@ -219,6 +222,8 @@ def __init__(
219222
file_prefix : Optional[PathLike]
220223
Prefix for output filenames. Default is inferred from structure name, or
221224
chemical formula of the structure.
225+
enable_progress_bar : bool
226+
Whether to show a progress bar during phonon calculations. Default is False.
222227
"""
223228
(read_kwargs, minimize_kwargs) = none_to_dict((read_kwargs, minimize_kwargs))
224229

@@ -235,6 +240,7 @@ def __init__(
235240
self.plot_to_file = plot_to_file
236241
self.write_results = write_results
237242
self.write_full = write_full
243+
self.enable_progress_bar = enable_progress_bar
238244

239245
# Ensure supercell is a valid list
240246
self.supercell = [supercell] * 3 if isinstance(supercell, int) else supercell
@@ -363,6 +369,11 @@ def calc_force_constants(
363369
phonon.generate_displacements(distance=self.displacement)
364370
disp_supercells = phonon.supercells_with_displacements
365371

372+
if self.enable_progress_bar:
373+
disp_supercells = track_progress(
374+
disp_supercells, "Computing displacements..."
375+
)
376+
366377
phonon.forces = [
367378
self._calc_forces(supercell)
368379
for supercell in disp_supercells

janus_core/cli/phonons.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ def phonons(
238238
"write_results": True,
239239
"write_full": write_full,
240240
"file_prefix": file_prefix,
241+
"enable_progress_bar": True,
241242
}
242243

243244
# Initialise phonons

janus_core/helpers/utils.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,19 @@
66
from io import StringIO
77
import logging
88
from pathlib import Path
9-
from typing import Any, Literal, Optional, TextIO, get_args
9+
from typing import Any, Literal, Optional, TextIO, Union, get_args
1010

1111
from ase import Atoms
1212
from ase.io import read, write
1313
from ase.io.formats import filetype
14+
from rich.progress import (
15+
BarColumn,
16+
MofNCompleteColumn,
17+
Progress,
18+
TextColumn,
19+
TimeRemainingColumn,
20+
)
21+
from rich.style import Style
1422
from spglib import get_spacegroup
1523

1624
from janus_core.helpers.janus_types import (
@@ -674,3 +682,44 @@ def _dump_csv(
674682

675683
for cols in zip(*columns.values()):
676684
print(",".join(map(format, cols, formats)), file=file)
685+
686+
687+
def track_progress(sequence: Union[Sequence, Iterable], description: str) -> Iterable:
688+
"""
689+
Track the progress of iterating over a sequence.
690+
691+
This is done by displaying a progress bar in the console using the rich library.
692+
The function is an iterator over the sequence, updating the progress bar each
693+
iteration.
694+
695+
Parameters
696+
----------
697+
sequence : Iterable
698+
The sequence to iterate over. Must support "len".
699+
description : str
700+
The text to display to the left of the progress bar.
701+
702+
Yields
703+
------
704+
Iterable
705+
An iterable of the values in the sequence.
706+
"""
707+
text_column = TextColumn("{task.description}")
708+
bar_column = BarColumn(
709+
bar_width=None,
710+
complete_style=Style(color="#FBBB10"),
711+
finished_style=Style(color="#E38408"),
712+
)
713+
completion_column = MofNCompleteColumn()
714+
time_column = TimeRemainingColumn()
715+
progress = Progress(
716+
text_column,
717+
bar_column,
718+
completion_column,
719+
time_column,
720+
expand=True,
721+
auto_refresh=False,
722+
)
723+
724+
with progress:
725+
yield from progress.track(sequence, description=description)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ numpy = "^1.26.4"
3333
phonopy = "^2.23.1"
3434
python = "^3.9"
3535
pyyaml = "^6.0.1"
36+
rich = "^13.9.1"
3637
seekpath = "^1.9.7"
3738
spglib = "^2.3.0"
3839
torch = ">= 2.1, <= 2.2" # Range required for dgl

0 commit comments

Comments
 (0)