File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -1076,7 +1076,7 @@ def collate_samples_into_batch_schema(samples: list[DataSampleSchema]) -> object
1076
1076
Instance of a ``BatchSchema`` object. This is not explicitly annotated
1077
1077
since the model/class is defined dynamically based off incoming data.
1078
1078
"""
1079
- ref_schema = samples [0 ].schema ()
1079
+ ref_schema = samples [0 ].model_json_schema ()
1080
1080
# initial keys are going to hold the main structure of the schema
1081
1081
schema_to_generate = {
1082
1082
"num_atoms" : (NDArray [Shape ["*" ], int ] | torch .LongTensor , ...),
@@ -1103,7 +1103,7 @@ def collate_samples_into_batch_schema(samples: list[DataSampleSchema]) -> object
1103
1103
schema_to_generate [key ] = (type (data ), ...)
1104
1104
collected_data [key ] = data
1105
1105
collected_data ["num_edges" ] = _concatenate_data_list (
1106
- [sample .graph .batch_num_edges ( ) for sample in samples ]
1106
+ [sample .graph .edge_index . size ( - 1 ) for sample in samples ]
1107
1107
).long ()
1108
1108
else :
1109
1109
from dgl import DGLGraph , batch
You can’t perform that action at this time.
0 commit comments