Skip to content

Commit 1b044b7

Browse files
authored
Huggingface datasets with parquet files for Ace dataset (#365)
* testing huggingface datasets with parquet files for Ace dataset * cleanup * dataset filtering * basic handling for training on gitea datasets
1 parent db81f38 commit 1b044b7

File tree

2 files changed

+102
-2
lines changed

2 files changed

+102
-2
lines changed

torchmdnet/datasets/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Distributed under the MIT License.
33
# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)
44

5-
from .ace import Ace
5+
from .ace import Ace, AceHF
66
from .ani import ANI1, ANI1CCX, ANI1X, ANI2X
77
from .comp6 import (
88
ANIMD,
@@ -28,6 +28,7 @@
2828

2929
__all__ = [
3030
"Ace",
31+
"AceHF",
3132
"ANIMD",
3233
"ANI1",
3334
"ANI1CCX",

torchmdnet/datasets/ace.py

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import os
88
import torch as pt
99
from torchmdnet.datasets.memdataset import MemmappedDataset
10-
from torch_geometric.data import Data
10+
from torch_geometric.data import Data, Dataset
1111
from tqdm import tqdm
1212

1313

@@ -291,3 +291,102 @@ def sample_iter(self, mol_ids=False):
291291
data = self.pre_transform(data)
292292

293293
yield data
294+
295+
296+
def download_gitea_dataset(path, tmpdir):
    """Clone (or refresh) a gitea repository and return the path of one file in it.

    The URL is split on ``/``. Its first five pieces
    (``ssh://<host>/<user>/<repo>``) form the clone URL and its last piece is
    taken as the file name. An optional ``branch/<name>`` or ``commit/<sha>``
    pair anywhere in the URL selects the revision to check out; the branch
    defaults to ``main`` and a commit takes precedence over a branch.

    Args:
        path (str): ``ssh://`` URL pointing at a file in a gitea repository.
        tmpdir (str): Directory under which the clone is kept, in a
            sub-directory named ``<user>_<repo>``. An existing clone is reused.

    Returns:
        str: ``<tmpdir>/<user>_<repo>/<file name>``.

    Raises:
        ValueError: If ``path`` is not an ``ssh://`` URL or is too short to
            contain a user and a repository name.
        ImportError: If GitPython is not installed.
    """
    # Validate with real exceptions: ``assert`` is stripped under ``python -O``.
    if not path.startswith("ssh://"):
        raise ValueError(f"Expected an ssh:// gitea URL, got: {path!r}")

    # Parse the gitea URL
    pieces = path.split("/")
    if len(pieces) < 5:
        raise ValueError(
            f"Expected a URL like ssh://<host>/<user>/<repo>/..., got: {path!r}"
        )
    repo_url = "/".join(pieces[:5])
    user = pieces[3]
    repo_name = pieces[4]
    # NOTE(review): only the last URL piece is used, so a file that lives in a
    # sub-directory of the repository would not resolve -- confirm the expected
    # URL layout.
    file_name = pieces[-1]
    branch = "main"
    commit = None
    if "branch" in pieces:
        branch = pieces[pieces.index("branch") + 1]
    if "commit" in pieces:
        commit = pieces[pieces.index("commit") + 1]

    try:
        from git import Repo
    except ImportError as err:
        raise ImportError(
            "Could not import GitPython library. Please install it first with `pip install GitPython`"
        ) from err

    outdir = os.path.join(tmpdir, f"{user}_{repo_name}")
    if os.path.exists(outdir):
        repo = Repo(outdir)
    else:
        repo = Repo.clone_from(repo_url, outdir, no_checkout=True)

    origin = repo.remotes.origin
    # Fetch rather than pull first: a cached clone may have been left on a
    # detached HEAD by an earlier ``commit/<sha>`` request, where ``git pull``
    # fails because there is no current branch to merge into.
    origin.fetch()
    repo.git.checkout(commit if commit is not None else branch)
    if not repo.head.is_detached:
        # On a branch: bring the branch that was actually requested up to date.
        origin.pull()

    return os.path.join(outdir, file_name)
333+
334+
335+
class AceHF(Dataset):
    """Ace dataset read from parquet files through HuggingFace ``datasets``.

    The samples are loaded with ``datasets.load_dataset`` and exposed one at a
    time as :obj:`torch_geometric.data.Data` objects.
    """

    def __init__(
        self, root="parquet", paths=None, split="train", max_gradient=None
    ) -> None:
        """Load the dataset and optionally filter it.

        Args:
            root: First argument forwarded to ``load_dataset``.
            paths: Data files forwarded to ``load_dataset`` as ``data_files``.
                Entries containing ``"gitea"`` are first resolved to a local
                file with :func:`download_gitea_dataset` (cloned under
                ``/tmp``). ``None`` is forwarded unchanged.
            split: Split forwarded to ``load_dataset``.
            max_gradient: If given, samples containing NaN in ``forces`` or
                ``formation_energy``, or whose largest force norm (taken along
                axis 1) is not below this value, are dropped.
        """
        from datasets import load_dataset
        import numpy as np

        # NOTE(review): the torch_geometric ``Dataset`` initializer is not
        # called here -- confirm that this is intentional.

        # Handle gitea parquet datasets
        if paths is None:
            # ``paths`` defaults to None; there is nothing to resolve then.
            newpaths = None
        else:
            newpaths = paths.copy()
            for i, path in enumerate(paths):
                if "gitea" in path:
                    newpaths[i] = download_gitea_dataset(path, "/tmp")

        self.dataset = load_dataset(root, data_files=newpaths, split=split)
        if max_gradient is not None:

            def _filter(x):
                if np.isnan(x["forces"]).any() or np.isnan(x["formation_energy"]).any():
                    return False
                # presumably forces is (n_atoms, 3), making this the largest
                # per-atom force magnitude -- TODO confirm
                return np.max(np.linalg.norm(x["forces"], axis=1)) < max_gradient

            # Use half of the CPUs but never fewer than one worker:
            # ``os.cpu_count()`` may return None, and halving a single CPU
            # would yield an invalid worker count of 0.
            num_proc = max(1, (os.cpu_count() or 1) // 2)
            self.dataset = self.dataset.filter(
                _filter, desc="Filtering", num_proc=num_proc
            )
        self.dataset = self.dataset.with_format("torch")

    def __len__(self):
        """Return the number of rows in the (possibly filtered) dataset."""
        return self.dataset.num_rows

    def __getitem__(self, idx):
        """Gets the data object at index :obj:`idx`.

        The data object contains the following attributes:

            - :obj:`z`: Atomic numbers of the atoms.
            - :obj:`pos`: Positions of the atoms.
            - :obj:`y`: Formation energy of the molecule.
            - :obj:`neg_dy`: Forces on the atoms.
            - :obj:`q`: Total charge of the molecule.
            - :obj:`pq`: Partial charges of the atoms.
            - :obj:`dp`: Dipole moment of the molecule.

        Args:
            idx (int): Index of the data object.

        Returns:
            :obj:`torch_geometric.data.Data`: The data object.
        """
        data = self.dataset[int(idx)]
        return Data(
            z=data["atomic_numbers"],
            pos=data["positions"],
            y=data["formation_energy"].view(1, 1),
            neg_dy=data["forces"],
            # Total charge is the sum of the per-atom formal charges.
            q=sum(data["formal_charges"]),
            pq=data["partial_charges"],
            dp=data["dipole_moment"],
        )

0 commit comments

Comments
 (0)