Skip to content

Commit 32d690c

Browse files
author
Kelvin Lee
authored
Merge pull request #337 from laserkelvin/batchschema-pyg-collate-fix
`BatchSchema` collate fix for PyG graphs
2 parents a28a5f5 + cd6b530 commit 32d690c

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

matsciml/datasets/schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,7 +1076,7 @@ def collate_samples_into_batch_schema(samples: list[DataSampleSchema]) -> object
10761076
Instance of a ``BatchSchema`` object. This is not explicitly annotated
10771077
since the model/class is defined dynamically based off incoming data.
10781078
"""
1079-
ref_schema = samples[0].schema()
1079+
ref_schema = samples[0].model_json_schema()
10801080
# initial keys are going to hold the main structure of the schema
10811081
schema_to_generate = {
10821082
"num_atoms": (NDArray[Shape["*"], int] | torch.LongTensor, ...),
@@ -1103,7 +1103,7 @@ def collate_samples_into_batch_schema(samples: list[DataSampleSchema]) -> object
11031103
schema_to_generate[key] = (type(data), ...)
11041104
collected_data[key] = data
11051105
collected_data["num_edges"] = _concatenate_data_list(
1106-
[sample.graph.batch_num_edges() for sample in samples]
1106+
[sample.graph.edge_index.size(-1) for sample in samples]
11071107
).long()
11081108
else:
11091109
from dgl import DGLGraph, batch

0 commit comments

Comments
 (0)