You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
import pandas as pd
import torch
from torch_geometric.data import Data
import os.path as osp
import numpy as np
from ogb.io.read_graph_raw import read_csv_graph_raw, read_csv_heterograph_raw, read_binary_graph_raw, read_binary_heterograph_raw
from tqdm.auto import tqdm
def process_one_graph(graph, additional_node_files, additional_edge_files):
    """Turn one raw homogeneous graph dict into a PyG ``Data`` object.

    ``graph`` is consumed along the way: every entry that gets copied onto
    the result is deleted from the dict afterwards.

    Args:
        graph: Dict with ``'num_nodes'``, ``'edge_index'``, ``'edge_feat'``
            and ``'node_feat'`` entries, plus one entry for every name in
            the two ``additional_*`` lists. ``'edge_feat'`` / ``'node_feat'``
            may be ``None``; the other array entries are passed to
            ``torch.from_numpy``.
        additional_node_files: Extra keys of ``graph`` to copy onto the result.
        additional_edge_files: Extra keys of ``graph`` to copy onto the result.

    Returns:
        The populated ``Data`` instance.
    """
    pyg_graph = Data()

    # Mandatory structure: node count and connectivity.
    pyg_graph.num_nodes = graph['num_nodes']
    pyg_graph.edge_index = torch.from_numpy(graph['edge_index'])
    for consumed_key in ('num_nodes', 'edge_index'):
        del graph[consumed_key]

    # Optional features, as (attribute on the Data object, key in the raw dict).
    # NOTE(review): the pasted source lost its indentation; the delete is
    # assumed to sit inside the "is not None" branch, so a None placeholder
    # stays in ``graph`` -- confirm against the upstream file.
    for attr_name, raw_key in (('edge_attr', 'edge_feat'), ('x', 'node_feat')):
        features = graph[raw_key]
        if features is None:
            continue
        setattr(pyg_graph, attr_name, torch.from_numpy(features))
        del graph[raw_key]

    # Extra per-node arrays first, then extra per-edge arrays.
    for extra_keys in (additional_node_files, additional_edge_files):
        for key in extra_keys:
            pyg_graph[key] = torch.from_numpy(graph[key])
            del graph[key]

    return pyg_graph
def process_one_heterograph(graph, additional_node_files, additional_edge_files):
    """Turn one raw heterogeneous graph dict into a PyG ``Data`` object.

    Every ``*_dict`` entry of ``graph`` is itself a dict of numpy arrays;
    each array is converted with ``torch.from_numpy`` and stored under the
    same inner key. ``graph`` is consumed: converted entries are deleted
    from it.

    Args:
        graph: Dict with ``'num_nodes_dict'``, ``'edge_index_dict'``,
            ``'edge_feat_dict'`` and ``'node_feat_dict'`` entries, plus one
            dict-valued entry for every name in the two ``additional_*``
            lists. The two feature dicts may be ``None``.
        additional_node_files: Extra keys of ``graph`` to copy onto the result.
        additional_edge_files: Extra keys of ``graph`` to copy onto the result.

    Returns:
        The populated ``Data`` instance.
    """
    hetero = Data()

    # The same node-count dict is exposed under both attribute names.
    node_counts = graph['num_nodes_dict']
    hetero.__num_nodes__ = node_counts
    hetero.num_nodes_dict = node_counts

    # add edge connectivity
    hetero.edge_index_dict = {
        triplet: torch.from_numpy(edge_index)
        for triplet, edge_index in graph['edge_index_dict'].items()
    }
    del graph['edge_index_dict']

    # NOTE(review): the pasted source lost its indentation; the deletes for
    # the two optional feature dicts are assumed to sit inside their
    # "is not None" branches -- confirm against the upstream file.
    edge_feats = graph['edge_feat_dict']
    if edge_feats is not None:
        hetero.edge_attr_dict = {
            triplet: torch.from_numpy(feat) for triplet, feat in edge_feats.items()
        }
        del graph['edge_feat_dict']

    node_feats = graph['node_feat_dict']
    if node_feats is not None:
        hetero.x_dict = {
            nodetype: torch.from_numpy(feat) for nodetype, feat in node_feats.items()
        }
        del graph['node_feat_dict']

    # Extra per-node-type dicts first, then extra per-edge-type dicts.
    for extra_keys in (additional_node_files, additional_edge_files):
        for key in extra_keys:
            hetero[key] = {
                inner_key: torch.from_numpy(array)
                for inner_key, array in graph[key].items()
            }
            del graph[key]

    return hetero
def process_graphs_in_parallel(graph_list, additional_node_files, additional_edge_files, num_workers=10, hetero_flag=False):
    """Convert raw graph dicts to PyG objects using a pool of worker processes.

    Args:
        graph_list: Iterable of raw graph dicts, one per graph.
        additional_node_files: Forwarded unchanged to the per-graph converter.
        additional_edge_files: Forwarded unchanged to the per-graph converter.
        num_workers: ``max_workers`` handed to ``ProcessPoolExecutor``.
        hetero_flag: When true each graph goes through
            ``process_one_heterograph``; otherwise ``process_one_graph``.

    Returns:
        List of converted graphs, in the same order as ``graph_list``.

    Raises:
        Whatever the per-graph converter raised in a worker; the exception
        surfaces here through ``Future.result()``.
    """
    # The module header never imports ProcessPoolExecutor, so the name was
    # undefined and every call failed with NameError. Importing it locally
    # keeps this function self-contained.
    from concurrent.futures import ProcessPoolExecutor

    process_func = process_one_heterograph if hetero_flag else process_one_graph

    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        futures = [
            executor.submit(process_func, graph, additional_node_files, additional_edge_files)
            for graph in graph_list
        ]
        # Drain the futures in submission order (not completion order) so the
        # output lines up index-for-index with graph_list.
        pyg_graph_list = [
            future.result()
            for future in tqdm(futures, desc="Processing graphs in parallel", total=len(futures))
        ]

    return pyg_graph_list
def read_graph_pyg(raw_dir, add_inverse_edge = False, additional_node_files = [], additional_edge_files = [], binary = False):
    """Load raw homogeneous graphs from ``raw_dir`` and convert them to PyG objects.

    Args:
        raw_dir: Location passed to the raw-graph reader.
        add_inverse_edge: Passed through to the raw-graph reader.
        additional_node_files: Extra node-level entries; given to the csv
            reader and to the per-graph conversion.
        additional_edge_files: Extra edge-level entries; given to the csv
            reader and to the per-graph conversion.
        binary: Selects ``read_binary_graph_raw`` (npz) when true, otherwise
            ``read_csv_graph_raw``.

    Returns:
        The list produced by ``process_graphs_in_parallel`` for the loaded graphs.
    """
    if not binary:
        # csv
        graph_list = read_csv_graph_raw(
            raw_dir,
            add_inverse_edge,
            additional_node_files = additional_node_files,
            additional_edge_files = additional_edge_files,
        )
    else:
        # npz -- the binary reader is not given the additional_* lists
        graph_list = read_binary_graph_raw(raw_dir, add_inverse_edge)

    print('Converting graphs into PyG objects...')
    print(f'The total length of graph_list is: {len(graph_list)}')

    return process_graphs_in_parallel(
        graph_list,
        additional_node_files,
        additional_edge_files,
        hetero_flag=False,
    )
def read_heterograph_pyg(raw_dir, add_inverse_edge = False, additional_node_files = [], additional_edge_files = [], binary = False):
    """Load raw heterogeneous graphs from ``raw_dir`` and convert them to PyG objects.

    Args:
        raw_dir: Location passed to the raw-heterograph reader.
        add_inverse_edge: Passed through to the raw-heterograph reader.
        additional_node_files: Extra node-level entries; given to the csv
            reader and to the per-graph conversion.
        additional_edge_files: Extra edge-level entries; given to the csv
            reader and to the per-graph conversion.
        binary: Selects ``read_binary_heterograph_raw`` (npz) when true,
            otherwise ``read_csv_heterograph_raw``.

    Returns:
        The list produced by ``process_graphs_in_parallel`` for the loaded graphs.
    """
    if not binary:
        # csv
        graph_list = read_csv_heterograph_raw(
            raw_dir,
            add_inverse_edge,
            additional_node_files = additional_node_files,
            additional_edge_files = additional_edge_files,
        )
    else:
        # npz -- the binary reader is not given the additional_* lists
        graph_list = read_binary_heterograph_raw(raw_dir, add_inverse_edge)

    print('Converting graphs into PyG objects...')
    print(f'The total length of graph_list is: {len(graph_list)}')

    return process_graphs_in_parallel(
        graph_list,
        additional_node_files,
        additional_edge_files,
        hetero_flag=True,
    )
# Script entry point is a deliberate no-op: running this module directly
# performs no work.
if __name__ == '__main__':
    pass
The text was updated successfully, but these errors were encountered:
brysonwx
changed the title from "support read_graph_pyg in parallel" to "support read_graph_pyg in parallel for ogb/io/read_graph_pyg.py"
Jan 8, 2025
The text was updated successfully, but these errors were encountered: