Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support parallel processing in read_graph_pyg (ogb/io/read_graph_pyg.py) #496

Open
brysonwx opened this issue Jan 8, 2025 · 0 comments
Open

Support parallel processing in read_graph_pyg (ogb/io/read_graph_pyg.py) #496

brysonwx opened this issue Jan 8, 2025 · 0 comments

Comments

@brysonwx
Copy link

brysonwx commented Jan 8, 2025

import pandas as pd
import torch
from torch_geometric.data import Data
import os.path as osp
import numpy as np
from ogb.io.read_graph_raw import read_csv_graph_raw, read_csv_heterograph_raw, read_binary_graph_raw, read_binary_heterograph_raw
from tqdm.auto import tqdm


def process_one_graph(graph, additional_node_files, additional_edge_files):
    """Convert one raw homogeneous graph dict into a PyG ``Data`` object.

    Args:
        graph: dict with ``num_nodes``, ``edge_index``, ``edge_feat`` and
            ``node_feat`` entries (the two feature entries may be None), plus
            one entry per key named in the additional file lists. Every entry
            that gets converted is deleted from this dict; ``None`` feature
            entries are left in place.
        additional_node_files: keys of extra node-level arrays to attach.
        additional_edge_files: keys of extra edge-level arrays to attach.

    Returns:
        A ``Data`` carrying ``num_nodes``, ``edge_index``, ``edge_attr`` /
        ``x`` when the matching feature is not None, and every additional key,
        with arrays wrapped via ``torch.from_numpy``.
    """
    data = Data()
    data.num_nodes = graph['num_nodes']
    data.edge_index = torch.from_numpy(graph['edge_index'])
    for consumed in ('num_nodes', 'edge_index'):
        del graph[consumed]

    # Optional features: edge features first, then node features.
    for raw_key, attr_name in (('edge_feat', 'edge_attr'), ('node_feat', 'x')):
        feature = graph[raw_key]
        if feature is None:
            continue
        setattr(data, attr_name, torch.from_numpy(feature))
        del graph[raw_key]

    # Extra node-level arrays, then extra edge-level arrays.
    for extra_key in (*additional_node_files, *additional_edge_files):
        data[extra_key] = torch.from_numpy(graph[extra_key])
        del graph[extra_key]

    return data


def process_one_heterograph(graph, additional_node_files, additional_edge_files):
    """Convert one raw heterogeneous graph dict into a PyG ``Data`` object.

    Args:
        graph: dict with ``num_nodes_dict``, ``edge_index_dict``,
            ``edge_feat_dict`` and ``node_feat_dict`` entries (the two feature
            dicts may be None), plus one per-type dict for each key named in
            the additional file lists. Converted entries are deleted from this
            dict; ``num_nodes_dict`` and ``None`` feature entries are kept.
        additional_node_files: keys of extra per-node-type array dicts.
        additional_edge_files: keys of extra per-edge-triplet array dicts.

    Returns:
        A ``Data`` whose ``__num_nodes__`` and ``num_nodes_dict`` both hold the
        raw ``num_nodes_dict``, and whose ``edge_index_dict``,
        ``edge_attr_dict``, ``x_dict`` and additional keys map each type to a
        tensor built with ``torch.from_numpy``.
    """
    def _tensorize(per_type_arrays):
        # Wrap every array of a {type: array} mapping as a torch tensor.
        return {type_key: torch.from_numpy(array)
                for type_key, array in per_type_arrays.items()}

    data = Data()

    node_counts = graph['num_nodes_dict']
    data.__num_nodes__ = node_counts
    data.num_nodes_dict = node_counts

    # Edge connectivity, keyed by edge triplet.
    data.edge_index_dict = _tensorize(graph['edge_index_dict'])
    del graph['edge_index_dict']

    edge_features = graph['edge_feat_dict']
    if edge_features is not None:
        data.edge_attr_dict = _tensorize(edge_features)
        del graph['edge_feat_dict']

    node_features = graph['node_feat_dict']
    if node_features is not None:
        data.x_dict = _tensorize(node_features)
        del graph['node_feat_dict']

    # Extra node-level dicts, then extra edge-level dicts.
    for extra_key in (*additional_node_files, *additional_edge_files):
        data[extra_key] = _tensorize(graph[extra_key])
        del graph[extra_key]

    return data


def process_graphs_in_parallel(graph_list, additional_node_files, additional_edge_files, num_workers=10, hetero_flag=False, chunksize=None):
    """Convert raw graph dicts to PyG objects, optionally across worker processes.

    Args:
        graph_list: sequence of raw graph dicts as returned by the
            ``ogb.io.read_graph_raw`` readers.
        additional_node_files: extra node-level keys forwarded to the converter.
        additional_edge_files: extra edge-level keys forwarded to the converter.
        num_workers: number of worker processes. ``None`` lets
            ``ProcessPoolExecutor`` choose its default. ``0`` converts serially
            in the current process; this value previously raised ``ValueError``.
        hetero_flag: use ``process_one_heterograph`` instead of
            ``process_one_graph``.
        chunksize: number of graphs shipped to a worker per task. By default
            the list is split into about four chunks per worker, so each graph
            does not pay for its own inter-process round trip.

    Returns:
        list of converted graphs, in the same order as ``graph_list``.
    """
    # ProcessPoolExecutor was used without being imported (NameError on every call).
    import os
    from concurrent.futures import ProcessPoolExecutor
    from functools import partial

    process_func = process_one_heterograph if hetero_flag else process_one_graph
    # A partial of a module-level function stays picklable for the workers.
    convert = partial(process_func,
                      additional_node_files=additional_node_files,
                      additional_edge_files=additional_edge_files)
    total = len(graph_list)

    if total == 0:
        # Nothing to convert; do not spin up a process pool.
        return []

    if num_workers == 0:
        # In-process path: the converters' ``del`` calls mutate the caller's dicts here.
        return [convert(graph) for graph in tqdm(graph_list, desc="Processing graphs", total=total)]

    if chunksize is None:
        workers = num_workers if num_workers is not None else (os.cpu_count() or 1)
        chunksize = max(1, total // (workers * 4))

    # NOTE(review): each graph is pickled to a worker and the result is pickled back.
    # The workers' ``del graph[...]`` only touch their copies, so the parent's
    # graph_list keeps every raw array alive until the caller drops it.
    # Returned tensors also go through torch's multiprocessing sharing, and very
    # large lists may exhaust file descriptors under the default strategy.
    # Confirm this on the target datasets.
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        results = executor.map(convert, graph_list, chunksize=chunksize)
        return list(tqdm(results, desc="Processing graphs in parallel", total=total))


def read_graph_pyg(raw_dir, add_inverse_edge = False, additional_node_files = None, additional_edge_files = None, binary = False, num_workers = 10):
    """Load the homogeneous graphs under ``raw_dir`` as a list of PyG objects.

    Args:
        raw_dir: directory holding the raw dataset files.
        add_inverse_edge: forwarded to the raw reader.
        additional_node_files: extra node-level keys to attach to every graph
            (``None`` means none). These are passed to the CSV reader only;
            the binary reader is called without them, but they are still
            forwarded to the converter in both cases.
        additional_edge_files: extra edge-level keys, handled like
            ``additional_node_files``.
        binary: read with ``read_binary_graph_raw`` (npz) instead of
            ``read_csv_graph_raw`` (csv).
        num_workers: forwarded to ``process_graphs_in_parallel``.

    Returns:
        list of converted graphs, as produced by ``process_graphs_in_parallel``.
    """
    # None sentinels replace the shared mutable ``[]`` defaults.
    if additional_node_files is None:
        additional_node_files = []
    if additional_edge_files is None:
        additional_edge_files = []

    if binary:
        # npz
        graph_list = read_binary_graph_raw(raw_dir, add_inverse_edge)
    else:
        # csv
        graph_list = read_csv_graph_raw(
            raw_dir, add_inverse_edge, additional_node_files = additional_node_files, additional_edge_files = additional_edge_files)

    print('Converting graphs into PyG objects...')
    print(f'The total length of graph_list is: {len(graph_list)}')

    return process_graphs_in_parallel(
        graph_list, additional_node_files, additional_edge_files,
        num_workers=num_workers, hetero_flag=False)


def read_heterograph_pyg(raw_dir, add_inverse_edge = False, additional_node_files = None, additional_edge_files = None, binary = False, num_workers = 10):
    """Load the heterogeneous graphs under ``raw_dir`` as a list of PyG objects.

    Args:
        raw_dir: directory holding the raw dataset files.
        add_inverse_edge: forwarded to the raw reader.
        additional_node_files: extra per-node-type keys to attach to every
            graph (``None`` means none). These are passed to the CSV reader
            only; the binary reader is called without them, but they are
            still forwarded to the converter in both cases.
        additional_edge_files: extra per-edge-type keys, handled like
            ``additional_node_files``.
        binary: read with ``read_binary_heterograph_raw`` (npz) instead of
            ``read_csv_heterograph_raw`` (csv).
        num_workers: forwarded to ``process_graphs_in_parallel``.

    Returns:
        list of converted graphs, as produced by ``process_graphs_in_parallel``.
    """
    # None sentinels replace the shared mutable ``[]`` defaults.
    if additional_node_files is None:
        additional_node_files = []
    if additional_edge_files is None:
        additional_edge_files = []

    if binary:
        # npz
        graph_list = read_binary_heterograph_raw(raw_dir, add_inverse_edge)
    else:
        # csv
        graph_list = read_csv_heterograph_raw(
            raw_dir, add_inverse_edge, additional_node_files = additional_node_files, additional_edge_files = additional_edge_files)

    print('Converting graphs into PyG objects...')
    print(f'The total length of graph_list is: {len(graph_list)}')

    return process_graphs_in_parallel(
        graph_list, additional_node_files, additional_edge_files,
        num_workers=num_workers, hetero_flag=True)

if __name__ == '__main__':
    # Running this file directly does nothing; it only defines conversion functions.
    ...

@brysonwx brysonwx changed the title support read_graph_pyg in parallel support read_graph_pyg in parallel for ogb/io/read_graph_pyg.py Jan 8, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant