|
1 | 1 | """Some utility functions used acrross the repository.""" |
2 | 2 |
|
| 3 | +import itertools |
| 4 | +import operator |
3 | 5 | import threading |
4 | 6 | from enum import Enum |
5 | 7 | from typing import Any |
6 | 8 |
|
| 9 | +import ase |
| 10 | +import ase.io |
7 | 11 | import ipywidgets as ipw |
8 | | -import more_itertools as mit |
9 | 12 | import numpy as np |
10 | | -import traitlets |
| 13 | +import traitlets as tl |
11 | 14 | from aiida.plugins import DataFactory |
12 | | -from ase import Atoms |
13 | | -from ase.io import read |
14 | 15 |
|
15 | 16 | CifData = DataFactory("core.cif") # pylint: disable=invalid-name |
16 | 17 | StructureData = DataFactory("core.structure") # pylint: disable=invalid-name |
@@ -43,24 +44,40 @@ def get_ase_from_file(fname, file_format=None): # pylint: disable=redefined-bui |
43 | 44 | # store_tags parameter is useful for CIF files |
44 | 45 | # https://wiki.fysik.dtu.dk/ase/ase/io/formatoptions.html#cif |
45 | 46 | if file_format == "cif": |
46 | | - traj = read(fname, format=file_format, index=":", store_tags=True) |
| 47 | + traj = ase.io.read(fname, format=file_format, index=":", store_tags=True) |
47 | 48 | else: |
48 | | - traj = read(fname, format=file_format, index=":") |
| 49 | + traj = ase.io.read(fname, format=file_format, index=":") |
49 | 50 | if not traj: |
50 | 51 | raise ValueError(f"Could not read any information from the file {fname}") |
51 | 52 | return traj |
52 | 53 |
|
53 | 54 |
|
54 | 55 | def find_ranges(iterable): |
55 | 56 | """Yield range of consecutive numbers.""" |
56 | | - for grp in mit.consecutive_groups(iterable): |
| 57 | + for grp in _consecutive_groups(iterable): |
57 | 58 | group = list(grp) |
58 | 59 | if len(group) == 1: |
59 | 60 | yield group[0] |
60 | 61 | else: |
61 | 62 | yield group[0], group[-1] |
62 | 63 |
|
63 | 64 |
|
| 65 | +def _consecutive_groups(iterable, ordering=lambda x: x): |
| 66 | + """Yield groups of consecutive items using :func:`itertools.groupby`. |
| 67 | + The *ordering* function determines whether two items are adjacent by |
| 68 | + returning their position. |
| 69 | +
|
| 70 | + This is a vendored version of more_itertools.consecutive_groups |
| 71 | + https://more-itertools.readthedocs.io/en/v10.3.0/_modules/more_itertools/more.html#consecutive_groups |
| 72 | + Distributed under MIT license: https://more-itertools.readthedocs.io/en/v10.3.0/license.html |
| 73 | + Thank you Bo Bayles for the original implementation. <3 |
| 74 | + """ |
| 75 | + for _, g in itertools.groupby( |
| 76 | + enumerate(iterable), key=lambda x: x[0] - ordering(x[1]) |
| 77 | + ): |
| 78 | + yield map(operator.itemgetter(1), g) |
| 79 | + |
| 80 | + |
64 | 81 | def list_to_string_range(lst, shift=1): |
65 | 82 | """Converts a list like [0, 2, 3, 4] into a string like '1 3..5'. |
66 | 83 |
|
@@ -124,15 +141,15 @@ def inverse_matrix(self): |
124 | 141 | return np.linalg.inv(self.matrix) |
125 | 142 |
|
126 | 143 |
|
127 | | -class _StatusWidgetMixin(traitlets.HasTraits): |
| 144 | +class _StatusWidgetMixin(tl.HasTraits): |
128 | 145 | """Show temporary messages for example for status updates. |
129 | 146 | This is a mixin class that is meant to be part of an inheritance |
130 | 147 | tree of an actual widget with a 'value' traitlet that is used |
131 | 148 | to convey a status message. See the non-private classes below |
132 | 149 | for examples. |
133 | 150 | """ |
134 | 151 |
|
135 | | - message = traitlets.Unicode(default_value="", allow_none=True) |
| 152 | + message = tl.Unicode(default_value="", allow_none=True) |
136 | 153 | new_line = "\n" |
137 | 154 |
|
138 | 155 | def __init__(self, clear_after=3, *args, **kwargs): |
@@ -169,7 +186,7 @@ class StatusHTML(_StatusWidgetMixin, ipw.HTML): |
169 | 186 |
|
170 | 187 | # This method should be part of _StatusWidgetMixin, but that does not work |
171 | 188 | # for an unknown reason. |
172 | | - @traitlets.observe("message") |
| 189 | + @tl.observe("message") |
173 | 190 | def _observe_message(self, change): |
174 | 191 | self.show_temporary_message(change["new"]) |
175 | 192 |
|
@@ -201,7 +218,7 @@ def wrap_message(message, level=MessageLevel.INFO): |
201 | 218 | """ |
202 | 219 |
|
203 | 220 |
|
204 | | -def ase2spglib(ase_structure: Atoms) -> tuple[Any, Any, Any]: |
| 221 | +def ase2spglib(ase_structure: ase.Atoms) -> tuple[Any, Any, Any]: |
205 | 222 | """ |
206 | 223 | Convert ase Atoms instance to spglib cell in the format defined at |
207 | 224 | https://spglib.github.io/spglib/python-spglib.html#crystal-structure-cell |
|
0 commit comments