|
7 | 7 | import os
|
8 | 8 | import torch as pt
|
9 | 9 | from torchmdnet.datasets.memdataset import MemmappedDataset
|
10 |
| -from torch_geometric.data import Data |
| 10 | +from torch_geometric.data import Data, Dataset |
11 | 11 | from tqdm import tqdm
|
12 | 12 |
|
13 | 13 |
|
@@ -291,3 +291,102 @@ def sample_iter(self, mol_ids=False):
|
291 | 291 | data = self.pre_transform(data)
|
292 | 292 |
|
293 | 293 | yield data
|
| 294 | + |
| 295 | + |
def _parse_gitea_url(path):
    """Split a gitea ``ssh://`` file URL into its components.

    The URL is split on ``/``; components 3 and 4 are taken as the user and
    repository name, and the last component as the file name. Optional
    ``branch/<name>`` and ``commit/<sha>`` component pairs located after the
    repository name select what to check out.

    Returns:
        tuple: ``(repo_url, user, repo_name, file_name, branch, commit)``
        where ``branch`` defaults to ``"main"`` and ``commit`` to ``None``.

    Raises:
        ValueError: If ``path`` does not start with ``ssh://`` or has no file
            component after the repository name.
    """
    if not path.startswith("ssh://"):
        raise ValueError(
            f"Expected a gitea URL starting with 'ssh://', got {path!r}"
        )

    # "ssh://host/user/repo/.../file" -> ["ssh:", "", host, user, repo, ..., file]
    pieces = path.split("/")
    if len(pieces) < 6:
        raise ValueError(
            f"Expected a URL of the form ssh://<host>/<user>/<repo>/.../<file>, got {path!r}"
        )

    repo_url = "/".join(pieces[:5])
    user = pieces[3]
    repo_name = pieces[4]
    file_name = pieces[-1]

    # Only look for markers after the repository name, and never treat the
    # final component as a marker: a marker needs a value following it.
    extra = pieces[5:]
    markers = extra[:-1]
    branch = "main"
    commit = None
    if "branch" in markers:
        branch = extra[markers.index("branch") + 1]
    if "commit" in markers:
        commit = extra[markers.index("commit") + 1]

    return repo_url, user, repo_name, file_name, branch, commit


def download_gitea_dataset(path, tmpdir):
    """Clone (or reuse) a gitea repository and return the local path of a file in it.

    The repository is cloned into ``<tmpdir>/<user>_<repo_name>`` on first use
    and reused on later calls. The remote is fetched every time, then either
    the requested commit or the requested branch (default ``"main"``) is
    checked out; a branch is additionally brought up to date with the remote.

    Args:
        path (str): Gitea URL starting with ``ssh://`` that ends in the file
            name and may contain ``branch/<name>`` or ``commit/<sha>``
            component pairs.
        tmpdir (str): Directory in which the clone is stored.

    Returns:
        str: ``<tmpdir>/<user>_<repo_name>/<file_name>``, where ``file_name``
        is the last component of ``path``.

    Raises:
        ValueError: If ``path`` is not a usable ``ssh://`` URL.
        ImportError: If GitPython is not installed.
    """
    # Validate with an explicit exception: a bare assert disappears under -O.
    repo_url, user, repo_name, file_name, branch, commit = _parse_gitea_url(path)

    try:
        from git import Repo
    except ImportError as err:
        raise ImportError(
            "Could not import GitPython library. Please install it first with `pip install GitPython`"
        ) from err

    outdir = os.path.join(tmpdir, f"{user}_{repo_name}")
    if os.path.exists(outdir):
        repo = Repo(outdir)
    else:
        repo = Repo.clone_from(repo_url, outdir, no_checkout=True)

    origin = repo.remotes.origin
    # Fetch instead of pull: a cached clone may be sitting on a detached HEAD
    # from an earlier commit checkout, where a plain `git pull` fails.
    origin.fetch()
    if commit is not None:
        repo.git.checkout(commit)
    else:
        repo.git.checkout(branch)
        # Now that we are on the branch, bring it up to date with the remote.
        origin.pull(branch)

    return os.path.join(outdir, file_name)
| 333 | + |
| 334 | + |
class AceHF(Dataset):
    """Dataset backed by a Hugging Face ``datasets`` table.

    The table is loaded with ``datasets.load_dataset(root, data_files=paths,
    split=split)``. Any path containing the substring ``"gitea"`` is first
    resolved to a local file with :func:`download_gitea_dataset` (cloned
    under ``/tmp``).

    Args:
        root (str): First argument passed to ``load_dataset`` (default
            ``"parquet"``).
        paths (str or list of str, optional): Data file(s) to load. ``None``
            is forwarded to ``load_dataset`` unchanged.
        split (str): Split passed to ``load_dataset``.
        max_gradient (float, optional): If given, samples containing NaN in
            ``forces`` or ``formation_energy``, or whose largest per-row
            force norm is not below this value, are dropped.
    """

    def __init__(
        self, root="parquet", paths=None, split="train", max_gradient=None
    ) -> None:
        from datasets import load_dataset
        import numpy as np

        # Handle gitea parquet datasets. ``paths`` may be None (the default)
        # or a single string, neither of which supports ``.copy()``.
        if paths is None:
            newpaths = None
        else:
            if isinstance(paths, str):
                paths = [paths]
            newpaths = [
                download_gitea_dataset(path, "/tmp") if "gitea" in path else path
                for path in paths
            ]

        self.dataset = load_dataset(root, data_files=newpaths, split=split)
        if max_gradient is not None:

            def _filter(x):
                if np.isnan(x["forces"]).any() or np.isnan(x["formation_energy"]).any():
                    return False
                return np.max(np.linalg.norm(x["forces"], axis=1)) < max_gradient

            # Use half of the CPUs, but never fewer than one worker:
            # os.cpu_count() can return None, and 1 // 2 == 0.
            num_proc = max(1, (os.cpu_count() or 1) // 2)
            self.dataset = self.dataset.filter(
                _filter, desc="Filtering", num_proc=num_proc
            )
        self.dataset = self.dataset.with_format("torch")

    def __len__(self):
        """Return the number of rows in the underlying dataset."""
        return self.dataset.num_rows

    def __getitem__(self, idx):
        """Gets the data object at index :obj:`idx`.

        The data object contains the following attributes:

            - :obj:`z`: Atomic numbers of the atoms.
            - :obj:`pos`: Positions of the atoms.
            - :obj:`y`: Formation energy of the molecule.
            - :obj:`neg_dy`: Forces on the atoms.
            - :obj:`q`: Total charge of the molecule.
            - :obj:`pq`: Partial charges of the atoms.
            - :obj:`dp`: Dipole moment of the molecule.

        Args:
            idx (int): Index of the data object.

        Returns:
            :obj:`torch_geometric.data.Data`: The data object.
        """
        data = self.dataset[int(idx)]
        return Data(
            z=data["atomic_numbers"],
            pos=data["positions"],
            y=data["formation_energy"].view(1, 1),
            neg_dy=data["forces"],
            # Total charge is the sum of the per-atom formal charges.
            q=data["formal_charges"].sum(),
            pq=data["partial_charges"],
            dp=data["dipole_moment"],
        )
0 commit comments