Source code for autogl.datasets._ogb

import numpy as np
import torch
import typing as _typing
from ogb.nodeproppred import NodePropPredDataset
from ogb.linkproppred import LinkPropPredDataset
from ogb.graphproppred import GraphPropPredDataset

from torch_sparse import SparseTensor

from autogl import backend as _backend
from autogl.data import InMemoryStaticGraphSet
from autogl.data.graph import (
    GeneralStaticGraph, GeneralStaticGraphGenerator
)
from ._dataset_registry import DatasetUniversalRegistry
from .utils import index_to_mask


class _OGBDatasetUtil:
    """Empty common base class for the OGB conversion helper classes."""


class _OGBNDatasetUtil(_OGBDatasetUtil):
    """Helpers turning OGB node-property-prediction data into a ``GeneralStaticGraph``."""

    @staticmethod
    def _tensors_from_mapping(
            data: _typing.Mapping[str, _typing.Any],
            key_mapping: _typing.Mapping[str, str]
    ) -> _typing.Dict[str, torch.Tensor]:
        """Convert the arrays selected by ``key_mapping`` into tensors.

        :param data: source mapping holding numpy arrays (or ``None``).
        :param key_mapping: maps a key of ``data`` to the key used in the result.
        :return: ``{target_key: tensor}``; entries whose source value is ``None``
            are skipped, because ``torch.from_numpy`` cannot convert ``None``.
        """
        return {
            target_key: torch.from_numpy(data[source_key])
            for source_key, target_key in key_mapping.items()
            if data[source_key] is not None
        }

    @classmethod
    def ogbn_data_to_general_static_graph(
            cls, ogbn_data: _typing.Mapping[str, _typing.Union[np.ndarray, int]],
            nodes_label: np.ndarray = ..., nodes_label_key: str = ...,
            train_index: _typing.Optional[np.ndarray] = ...,
            val_index: _typing.Optional[np.ndarray] = ...,
            test_index: _typing.Optional[np.ndarray] = ...,
            nodes_data_key_mapping: _typing.Optional[_typing.Mapping[str, str]] = ...,
            edges_data_key_mapping: _typing.Optional[_typing.Mapping[str, str]] = ...,
            graph_data_key_mapping: _typing.Optional[_typing.Mapping[str, str]] = ...
    ) -> GeneralStaticGraph:
        """Build a homogeneous static graph from one OGB node-level graph dict.

        :param ogbn_data: graph mapping; the keys ``'edge_index'``, ``'num_nodes'``
            and ``'edge_feat'`` are always read, plus every source key named in
            the key mappings. The mapping itself is left unmodified.
        :param nodes_label: node labels; stored (squeezed) in the nodes data
            only when it is an ``np.ndarray`` and ``nodes_label_key`` is a ``str``.
        :param nodes_label_key: nodes data key under which the labels are stored.
        :param train_index: node indices turned into ``'train_mask'`` when given
            as an ``np.ndarray``.
        :param val_index: node indices turned into ``'val_mask'`` when given
            as an ``np.ndarray``.
        :param test_index: node indices turned into ``'test_mask'`` when given
            as an ``np.ndarray``.
        :param nodes_data_key_mapping: ``{source key in ogbn_data: nodes data key}``.
            Must be a mapping in practice: ``.items()`` is called on it
            unconditionally, so the ``...`` default is not usable.
        :param edges_data_key_mapping: ``{source key: edges data key}``; when it
            is not a mapping, ``...`` is forwarded to the graph generator.
        :param graph_data_key_mapping: ``{source key: graph data key}``; when it
            is not a mapping, ``...`` is forwarded to the graph generator.
        :return: the graph created by
            ``GeneralStaticGraphGenerator.create_homogeneous_static_graph``.
        :raises ValueError: if ``nodes_label_key`` contains a space.
        """
        # Shallow copy: the re-ordered edge features are stored locally instead
        # of being written back into the caller's (possibly read-only) mapping.
        data = dict(ogbn_data)
        num_nodes = data['num_nodes']
        source_edge_index = data['edge_index']
        edge_feat = data['edge_feat']
        if edge_feat is not None:
            edge_feat = torch.tensor(edge_feat)
        adjacency = SparseTensor(
            row=torch.tensor(source_edge_index[0]),
            col=torch.tensor(source_edge_index[1]),
            value=edge_feat,
            sparse_sizes=(num_nodes, num_nodes)
        )
        # Read the features back from the SparseTensor so that they follow its
        # edge ordering (presumably row-major sorted -- confirm in torch_sparse).
        _, _, value = adjacency.coo()
        data['edge_feat'] = (
            value.cpu().detach().numpy() if value is not None else edge_feat
        )
        # TODO(review): the edge features above are taken BEFORE symmetrization,
        # so their count presumably only matches the symmetrized edge index when
        # the input graph already contains both directions -- confirm.
        row, col, _ = adjacency.to_symmetric().coo()
        # Stack directly in torch; no round trip through numpy is needed.
        edge_index = torch.stack([row, col]).cpu()

        homogeneous_static_graph: GeneralStaticGraph = (
            GeneralStaticGraphGenerator.create_homogeneous_static_graph(
                cls._tensors_from_mapping(data, nodes_data_key_mapping),
                edge_index,
                (
                    cls._tensors_from_mapping(data, edges_data_key_mapping)
                    if isinstance(edges_data_key_mapping, _typing.Mapping) else ...
                ),
                (
                    cls._tensors_from_mapping(data, graph_data_key_mapping)
                    if isinstance(graph_data_key_mapping, _typing.Mapping) else ...
                )
            )
        )

        if isinstance(nodes_label, np.ndarray) and isinstance(nodes_label_key, str):
            if ' ' in nodes_label_key:
                raise ValueError("Illegal nodes label key")
            homogeneous_static_graph.nodes.data[nodes_label_key] = (
                torch.from_numpy(nodes_label).squeeze()
            )

        # Each split that is provided as an index array becomes a node mask.
        for mask_key, split_index in (
                ('train_mask', train_index),
                ('val_mask', val_index),
                ('test_mask', test_index)
        ):
            if isinstance(split_index, np.ndarray):
                homogeneous_static_graph.nodes.data[mask_key] = index_to_mask(
                    torch.from_numpy(split_index), num_nodes
                )
        return homogeneous_static_graph

    @classmethod
    def ogbn_dataset_to_general_static_graph(
            cls, ogbn_dataset: NodePropPredDataset,
            nodes_label_key: str,
            nodes_data_key_mapping: _typing.Optional[_typing.Mapping[str, str]] = ...,
            edges_data_key_mapping: _typing.Optional[_typing.Mapping[str, str]] = ...,
            graph_data_key_mapping: _typing.Optional[_typing.Mapping[str, str]] = ...
    ) -> GeneralStaticGraph:
        """Convert the first item of an OGB node-level dataset into a static graph.

        The dataset's ``get_idx_split()`` result supplies the ``"train"``,
        ``"valid"`` and ``"test"`` indices; ``ogbn_dataset[0]`` supplies the graph
        mapping (element 0) and the labels (element 1).

        :param ogbn_dataset: loaded ``NodePropPredDataset``.
        :param nodes_label_key: nodes data key under which labels are stored.
        :param nodes_data_key_mapping: forwarded to
            ``ogbn_data_to_general_static_graph``.
        :param edges_data_key_mapping: forwarded to
            ``ogbn_data_to_general_static_graph``.
        :param graph_data_key_mapping: forwarded to
            ``ogbn_data_to_general_static_graph``.
        :return: the converted ``GeneralStaticGraph``.
        """
        split_idx = ogbn_dataset.get_idx_split()
        first_item = ogbn_dataset[0]
        return cls.ogbn_data_to_general_static_graph(
            first_item[0],
            first_item[1],
            nodes_label_key,
            split_idx["train"],
            split_idx["valid"],
            split_idx["test"],
            nodes_data_key_mapping,
            edges_data_key_mapping,
            graph_data_key_mapping
        )


@DatasetUniversalRegistry.register_dataset("ogbn-products")
class OGBNProductsDataset(InMemoryStaticGraphSet):
    """``ogbn-products`` wrapped as a one-graph ``InMemoryStaticGraphSet``."""

    def __init__(self, path: str):
        """Load ``ogbn-products`` and convert it for the active backend.

        :param path: second positional argument of ``NodePropPredDataset``
            (presumably the dataset root directory).
        """
        ogbn_dataset = NodePropPredDataset("ogbn-products", path)
        if _backend.DependentBackend.is_dgl():
            # NOTE(review): 'edge_feat' is mapped only on the DGL branch --
            # confirm that this dataset actually provides edge features.
            conversion_args = (
                "label", {"node_feat": "feat"}, {"edge_feat": "edge_feat"}
            )
        elif _backend.DependentBackend.is_pyg():
            conversion_args = ("y", {"node_feat": "x"})
        else:
            # Neither backend is active: the parent initializer is not invoked.
            return
        super(OGBNProductsDataset, self).__init__([
            _OGBNDatasetUtil.ogbn_dataset_to_general_static_graph(
                ogbn_dataset, *conversion_args
            )
        ])
@DatasetUniversalRegistry.register_dataset("ogbn-proteins")
class OGBNProteinsDataset(InMemoryStaticGraphSet):
    """``ogbn-proteins`` wrapped as a one-graph ``InMemoryStaticGraphSet``."""

    def __init__(self, path: str):
        """Load ``ogbn-proteins`` and convert it for the active backend.

        :param path: second positional argument of ``NodePropPredDataset``
            (presumably the dataset root directory).
        """
        ogbn_dataset = NodePropPredDataset("ogbn-proteins", path)
        # Only the label key depends on the backend; data mappings are shared.
        if _backend.DependentBackend.is_dgl():
            label_key = "label"
        elif _backend.DependentBackend.is_pyg():
            label_key = "y"
        else:
            # Neither backend is active: the parent initializer is not invoked.
            return
        super(OGBNProteinsDataset, self).__init__([
            _OGBNDatasetUtil.ogbn_dataset_to_general_static_graph(
                ogbn_dataset,
                label_key,
                {"node_species": "species"},
                {"edge_feat": "edge_feat"}
            )
        ])
@DatasetUniversalRegistry.register_dataset("ogbn-arxiv")
class OGBNArxivDataset(InMemoryStaticGraphSet):
    """``ogbn-arxiv`` wrapped as a one-graph ``InMemoryStaticGraphSet``."""

    def __init__(self, path: str):
        """Load ``ogbn-arxiv`` and convert it for the active backend.

        :param path: second positional argument of ``NodePropPredDataset``
            (presumably the dataset root directory).
        """
        ogbn_dataset = NodePropPredDataset("ogbn-arxiv", path)
        # Label and feature keys follow the naming of the active backend.
        if _backend.DependentBackend.is_dgl():
            label_key, feature_key = "label", "feat"
        elif _backend.DependentBackend.is_pyg():
            label_key, feature_key = "y", "x"
        else:
            # Neither backend is active: the parent initializer is not invoked.
            return
        super(OGBNArxivDataset, self).__init__([
            _OGBNDatasetUtil.ogbn_dataset_to_general_static_graph(
                ogbn_dataset,
                label_key,
                {"node_feat": feature_key, "node_year": "year"}
            )
        ])
@DatasetUniversalRegistry.register_dataset("ogbn-papers100M")
class OGBNPapers100MDataset(InMemoryStaticGraphSet):
    """``ogbn-papers100M`` wrapped as a one-graph ``InMemoryStaticGraphSet``."""

    def __init__(self, path: str):
        """Load ``ogbn-papers100M`` and convert it for the active backend.

        :param path: second positional argument of ``NodePropPredDataset``
            (presumably the dataset root directory).
        """
        ogbn_dataset = NodePropPredDataset("ogbn-papers100M", path)
        # Label and feature keys follow the naming of the active backend.
        if _backend.DependentBackend.is_dgl():
            label_key, feature_key = "label", "feat"
        elif _backend.DependentBackend.is_pyg():
            label_key, feature_key = "y", "x"
        else:
            # Neither backend is active: the parent initializer is not invoked.
            return
        super(OGBNPapers100MDataset, self).__init__([
            _OGBNDatasetUtil.ogbn_dataset_to_general_static_graph(
                ogbn_dataset,
                label_key,
                {"node_feat": feature_key, "node_year": "year"}
            )
        ])
# todo: currently heterogeneous dataset `ogbn-mag` NOT supported


class _OGBLDatasetUtil(_OGBDatasetUtil):
    """Helpers turning OGB link-property-prediction data into a ``GeneralStaticGraph``."""

    @classmethod
    def ogbl_data_to_general_static_graph(
            cls, ogbl_data: _typing.Mapping[str, _typing.Union[np.ndarray, int]],
            heterogeneous_edges: _typing.Mapping[
                _typing.Tuple[str, str, str],
                _typing.Union[
                    torch.Tensor,
                    _typing.Tuple[
                        torch.Tensor,
                        _typing.Optional[_typing.Mapping[str, torch.Tensor]]
                    ]
                ]
            ] = ...,
            nodes_data_key_mapping: _typing.Optional[_typing.Mapping[str, str]] = ...,
            graph_data_key_mapping: _typing.Optional[_typing.Mapping[str, str]] = ...
    ) -> GeneralStaticGraph:
        """Build a heterogeneous static graph with a single node type ``''``.

        :param ogbl_data: graph mapping from which the arrays named by the key
            mappings are read.
        :param heterogeneous_edges: forwarded unchanged to the graph generator;
            keyed by 3-tuples of strings, each value either an edge tensor or a
            ``(edge tensor, optional edge data mapping)`` pair.
        :param nodes_data_key_mapping: ``{source key in ogbl_data: nodes data key}``.
            Must be a mapping in practice: ``.items()`` is called on it
            unconditionally, so the ``...`` default is not usable.
        :param graph_data_key_mapping: ``{source key: graph data key}``; when it
            is not a mapping, ``...`` is forwarded to the graph generator.
        :return: the graph created by
            ``GeneralStaticGraphGenerator.create_heterogeneous_static_graph``.
        """
        # All node data is attached to the single, unnamed node type ''.
        nodes_data = {
            '': {
                target_data_key: torch.from_numpy(ogbl_data[source_data_key]).squeeze()
                for source_data_key, target_data_key in nodes_data_key_mapping.items()
            }
        }
        if isinstance(graph_data_key_mapping, _typing.Mapping):
            graph_data = {
                target_data_key: torch.from_numpy(ogbl_data[source_data_key]).squeeze()
                for source_data_key, target_data_key in graph_data_key_mapping.items()
            }
        else:
            graph_data = ...
        return GeneralStaticGraphGenerator.create_heterogeneous_static_graph(
            nodes_data, heterogeneous_edges, graph_data
        )
@DatasetUniversalRegistry.register_dataset("ogbl-ppa")
class OGBLPPADataset(InMemoryStaticGraphSet):
    """``ogbl-ppa`` wrapped as a one-graph ``InMemoryStaticGraphSet``."""

    def __init__(self, path: str):
        """Load ``ogbl-ppa`` and register its edge splits as typed edges.

        :param path: second positional argument of ``LinkPropPredDataset``
            (presumably the dataset root directory).
        """
        ogbl_dataset = LinkPropPredDataset("ogbl-ppa", path)
        edge_split = ogbl_dataset.get_edge_split()
        graph_data = ogbl_dataset[0]
        train_split = edge_split['train']
        valid_split = edge_split['valid']
        test_split = edge_split['test']
        # The '' relation carries the graph's own edges; the named relations
        # carry the positive / negative edges of each split.
        typed_edges = {
            ('', '', ''): torch.from_numpy(graph_data['edge_index']),
            ('', 'train_pos_edge', ''): torch.from_numpy(train_split['edge']),
            ('', 'val_pos_edge', ''): torch.from_numpy(valid_split['edge']),
            ('', 'val_neg_edge', ''): torch.from_numpy(valid_split['edge_neg']),
            ('', 'test_pos_edge', ''): torch.from_numpy(test_split['edge']),
            ('', 'test_neg_edge', ''): torch.from_numpy(test_split['edge_neg'])
        }
        if _backend.DependentBackend.is_dgl():
            nodes_key_mapping = {'node_feat': 'feat'}
        else:
            nodes_key_mapping = {'node_feat': 'x'}
        super(OGBLPPADataset, self).__init__([
            _OGBLDatasetUtil.ogbl_data_to_general_static_graph(
                graph_data, typed_edges, nodes_key_mapping
            )
        ])
@DatasetUniversalRegistry.register_dataset("ogbl-collab")
class OGBLCOLLABDataset(InMemoryStaticGraphSet):
    """``ogbl-collab`` wrapped as a one-graph ``InMemoryStaticGraphSet``."""

    def __init__(self, path: str):
        """Load ``ogbl-collab`` and register its edge splits as typed edges.

        Positive edges of every split carry ``'weight'`` and ``'year'`` edge data;
        negative edges are plain tensors.

        :param path: second positional argument of ``LinkPropPredDataset``
            (presumably the dataset root directory).
        """
        ogbl_dataset = LinkPropPredDataset("ogbl-collab", path)
        edge_split = ogbl_dataset.get_edge_split()
        graph_data = ogbl_dataset[0]

        def positive_edges(split):
            # (edge tensor, edge data) pair for the positive edges of one split.
            return (
                torch.from_numpy(split['edge']),
                {
                    'weight': torch.from_numpy(split['weight']),
                    'year': torch.from_numpy(split['year'])
                }
            )

        train_split = edge_split['train']
        valid_split = edge_split['valid']
        test_split = edge_split['test']
        typed_edges = {
            ('', '', ''): torch.from_numpy(graph_data['edge_index']),
            ('', 'train_pos_edge', ''): positive_edges(train_split),
            ('', 'val_pos_edge', ''): positive_edges(valid_split),
            ('', 'val_neg_edge', ''): torch.from_numpy(valid_split['edge_neg']),
            ('', 'test_pos_edge', ''): positive_edges(test_split),
            ('', 'test_neg_edge', ''): torch.from_numpy(test_split['edge_neg'])
        }
        if _backend.DependentBackend.is_dgl():
            nodes_key_mapping = {'node_feat': 'feat'}
        else:
            nodes_key_mapping = {'node_feat': 'x'}
        super(OGBLCOLLABDataset, self).__init__([
            _OGBLDatasetUtil.ogbl_data_to_general_static_graph(
                graph_data, typed_edges, nodes_key_mapping
            )
        ])
@DatasetUniversalRegistry.register_dataset("ogbl-ddi")
class OGBLDDIDataset(InMemoryStaticGraphSet):
    """``ogbl-ddi`` wrapped as a one-graph ``InMemoryStaticGraphSet``."""

    def __init__(self, path: str):
        """Load ``ogbl-ddi`` and register its edge splits as typed edges.

        No node features are read; the only node data is ``'_NID'``, a range
        over the graph's ``'num_nodes'``.

        :param path: second positional argument of ``LinkPropPredDataset``
            (presumably the dataset root directory).
        """
        ogbl_dataset = LinkPropPredDataset("ogbl-ddi", path)
        edge_split = ogbl_dataset.get_edge_split()
        graph_data = ogbl_dataset[0]
        nodes_data = {'': {'_NID': torch.arange(graph_data['num_nodes'])}}
        train_split = edge_split['train']
        valid_split = edge_split['valid']
        test_split = edge_split['test']
        typed_edges = {
            ('', '', ''): torch.from_numpy(graph_data['edge_index']),
            ('', 'train_pos_edge', ''): torch.from_numpy(train_split['edge']),
            ('', 'val_pos_edge', ''): torch.from_numpy(valid_split['edge']),
            ('', 'val_neg_edge', ''): torch.from_numpy(valid_split['edge_neg']),
            ('', 'test_pos_edge', ''): torch.from_numpy(test_split['edge']),
            ('', 'test_neg_edge', ''): torch.from_numpy(test_split['edge_neg'])
        }
        super(OGBLDDIDataset, self).__init__([
            GeneralStaticGraphGenerator.create_heterogeneous_static_graph(
                nodes_data, typed_edges
            )
        ])
@DatasetUniversalRegistry.register_dataset("ogbl-citation")
@DatasetUniversalRegistry.register_dataset("ogbl-citation2")
class OGBLCitation2Dataset(InMemoryStaticGraphSet):
    """``ogbl-citation2`` wrapped as a one-graph ``InMemoryStaticGraphSet``.

    Registered under both ``"ogbl-citation"`` and ``"ogbl-citation2"``; the
    data that is loaded is always ``"ogbl-citation2"``.
    """

    def __init__(self, path: str):
        """Load ``ogbl-citation2`` and register its edge splits as typed edges.

        :param path: second positional argument of ``LinkPropPredDataset``
            (presumably the dataset root directory).
        """
        ogbl_dataset = LinkPropPredDataset("ogbl-citation2", path)
        edge_split = ogbl_dataset.get_edge_split()
        graph_data = ogbl_dataset[0]
        train_split = edge_split['train']
        valid_split = edge_split['valid']
        test_split = edge_split['test']
        # NOTE(review): presumably the edge split of this dataset is keyed by
        # source/target node arrays rather than 'edge' / 'edge_neg' -- verify
        # these lookups against the OGB split format.
        typed_edges = {
            ('', '', ''): torch.from_numpy(graph_data['edge_index']),
            ('', 'train_pos_edge', ''): torch.from_numpy(train_split['edge']),
            ('', 'val_pos_edge', ''): torch.from_numpy(valid_split['edge']),
            ('', 'val_neg_edge', ''): torch.from_numpy(valid_split['edge_neg']),
            ('', 'test_pos_edge', ''): torch.from_numpy(test_split['edge']),
            ('', 'test_neg_edge', ''): torch.from_numpy(test_split['edge_neg'])
        }
        if _backend.DependentBackend.is_dgl():
            nodes_key_mapping = {'node_feat': 'feat', 'node_year': 'year'}
        else:
            nodes_key_mapping = {'node_feat': 'x', 'node_year': 'year'}
        super(OGBLCitation2Dataset, self).__init__([
            _OGBLDatasetUtil.ogbl_data_to_general_static_graph(
                graph_data, typed_edges, nodes_key_mapping
            )
        ])
# todo: currently the knowledge-graph datasets `ogbl-wikikg2` and `ogbl-biokg` are NOT supported


class _OGBGDatasetUtil:
    """Empty placeholder class; it defines no members."""
@DatasetUniversalRegistry.register_dataset("ogbg-molhiv")
class OGBGMOLHIVDataset(InMemoryStaticGraphSet):
    """``ogbg-molhiv`` as an ``InMemoryStaticGraphSet`` with one graph per item."""

    def __init__(self, path: str):
        """Load ``ogbg-molhiv`` and convert every ``(graph, label)`` item.

        :param path: second positional argument of ``GraphPropPredDataset``
            (presumably the dataset root directory).
        """
        ogbg_dataset = GraphPropPredDataset("ogbg-molhiv", path)
        idx_split: _typing.Mapping[str, np.ndarray] = ogbg_dataset.get_idx_split()
        train_index: _typing.Any = idx_split['train'].tolist()
        test_index: _typing.Any = idx_split['test'].tolist()
        val_index: _typing.Any = idx_split['valid'].tolist()

        # Node feature and label keys follow the naming of the active backend.
        if _backend.DependentBackend.is_dgl():
            feature_key, label_key = "feat", "label"
        else:
            feature_key, label_key = "x", "y"

        graphs = []
        for data, label in ogbg_dataset:
            graphs.append(
                GeneralStaticGraphGenerator.create_homogeneous_static_graph(
                    {feature_key: torch.from_numpy(data['node_feat'])},
                    torch.from_numpy(data['edge_index']),
                    {'edge_feat': torch.from_numpy(data['edge_feat'])},
                    {label_key: torch.from_numpy(label)}
                )
            )
        super(OGBGMOLHIVDataset, self).__init__(
            graphs, train_index, val_index, test_index
        )
@DatasetUniversalRegistry.register_dataset("ogbg-molpcba")
class OGBGMOLPCBADataset(InMemoryStaticGraphSet):
    """``ogbg-molpcba`` as an ``InMemoryStaticGraphSet`` with one graph per item."""

    def __init__(self, path: str):
        """Load ``ogbg-molpcba`` and convert every ``(graph, label)`` item.

        :param path: second positional argument of ``GraphPropPredDataset``
            (presumably the dataset root directory).
        """
        # Load the dataset this class is registered for (the previous code
        # loaded "ogbg-molhiv" here, a copy-paste slip).
        ogbg_dataset = GraphPropPredDataset("ogbg-molpcba", path)
        idx_split: _typing.Mapping[str, np.ndarray] = ogbg_dataset.get_idx_split()
        train_index: _typing.Any = idx_split['train'].tolist()
        test_index: _typing.Any = idx_split['test'].tolist()
        val_index: _typing.Any = idx_split['valid'].tolist()

        # Node feature and label keys follow the naming of the active backend.
        if _backend.DependentBackend.is_dgl():
            feature_key, label_key = "feat", "label"
        else:
            feature_key, label_key = "x", "y"

        super(OGBGMOLPCBADataset, self).__init__(
            [
                GeneralStaticGraphGenerator.create_homogeneous_static_graph(
                    {feature_key: torch.from_numpy(data['node_feat'])},
                    torch.from_numpy(data['edge_index']),
                    {'edge_feat': torch.from_numpy(data['edge_feat'])},
                    {label_key: torch.from_numpy(label)}
                )
                for data, label in ogbg_dataset
            ],
            train_index, val_index, test_index
        )
@DatasetUniversalRegistry.register_dataset("ogbg-ppa")
class OGBGPPADataset(InMemoryStaticGraphSet):
    """``ogbg-ppa`` as an ``InMemoryStaticGraphSet`` with one graph per item."""

    def __init__(self, path: str):
        """Load ``ogbg-ppa`` and convert every ``(graph, label)`` item.

        No node features are read; the only node data is ``'_NID'``, a range
        over each graph's ``'num_nodes'``.

        :param path: second positional argument of ``GraphPropPredDataset``
            (presumably the dataset root directory).
        """
        # Load the dataset this class is registered for (the previous code
        # loaded "ogbg-molhiv" here, a copy-paste slip).
        ogbg_dataset = GraphPropPredDataset("ogbg-ppa", path)
        idx_split: _typing.Mapping[str, np.ndarray] = ogbg_dataset.get_idx_split()
        train_index: _typing.Any = idx_split['train'].tolist()
        test_index: _typing.Any = idx_split['test'].tolist()
        val_index: _typing.Any = idx_split['valid'].tolist()

        # The label key follows the naming of the active backend.
        label_key = 'label' if _backend.DependentBackend.is_dgl() else 'y'

        super(OGBGPPADataset, self).__init__(
            [
                GeneralStaticGraphGenerator.create_homogeneous_static_graph(
                    {'_NID': torch.arange(data['num_nodes'])},
                    torch.from_numpy(data['edge_index']),
                    {'edge_feat': torch.from_numpy(data['edge_feat'])},
                    {label_key: torch.from_numpy(label)}
                )
                for data, label in ogbg_dataset
            ],
            train_index, val_index, test_index
        )
@DatasetUniversalRegistry.register_dataset("ogbg-code")
@DatasetUniversalRegistry.register_dataset("ogbg-code2")
class OGBGCode2Dataset(InMemoryStaticGraphSet):
    """``ogbg-code2`` as an ``InMemoryStaticGraphSet`` with one graph per item.

    Registered under both ``"ogbg-code"`` and ``"ogbg-code2"``; the data that is
    loaded is always ``"ogbg-code2"``.
    """

    def __init__(self, path: str):
        """Load ``ogbg-code2`` and convert the graph of every item.

        Only node data and the edge index are converted; the per-item label is
        not attached to the graphs.

        :param path: second positional argument of ``GraphPropPredDataset``
            (presumably the dataset root directory).
        """
        # Load the dataset this class is registered for (the previous code
        # loaded "ogbg-molhiv" here, a copy-paste slip).
        ogbg_dataset = GraphPropPredDataset("ogbg-code2", path)
        idx_split: _typing.Mapping[str, np.ndarray] = ogbg_dataset.get_idx_split()
        train_index: _typing.Any = idx_split['train'].tolist()
        test_index: _typing.Any = idx_split['test'].tolist()
        val_index: _typing.Any = idx_split['valid'].tolist()

        # Only the node feature key depends on the active backend.
        feature_key = "feat" if _backend.DependentBackend.is_dgl() else "x"

        super(OGBGCode2Dataset, self).__init__(
            [
                GeneralStaticGraphGenerator.create_homogeneous_static_graph(
                    {
                        feature_key: torch.from_numpy(data['node_feat']),
                        "node_is_attributed": torch.from_numpy(data["node_is_attributed"]),
                        "node_dfs_order": torch.from_numpy(data["node_dfs_order"]),
                        "node_depth": torch.from_numpy(data["node_depth"])
                    },
                    torch.from_numpy(data['edge_index'])
                )
                for data, _label in ogbg_dataset
            ],
            train_index, val_index, test_index
        )