Source code for autogl.datasets._pyg

import os
from autogl.data.graph import GeneralStaticGraphGenerator
from autogl.data import InMemoryStaticGraphSet
from ._dataset_registry import DatasetUniversalRegistry
import torch_geometric
from torch_geometric.datasets import (
    Amazon, Coauthor, Flickr, ModelNet,
    Planetoid, PPI, QM9, Reddit, TUDataset
)


@DatasetUniversalRegistry.register_dataset("cora")
class CoraDataset(InMemoryStaticGraphSet):
    """Cora dataset loaded through PyG's ``Planetoid``.

    The set contains exactly one homogeneous static graph, built from the
    first item of the PyG dataset.
    """

    def __init__(self, path: str):
        """Load Cora and wrap it as a one-graph set.

        :param path: base directory; ``os.path.join(path, '_pyg')`` is
            handed to ``Planetoid`` as its first argument.
        """
        source = Planetoid(os.path.join(path, '_pyg'), "Cora")
        # Remove the item-list attribute under either spelling if present
        # (presumably a per-item cache on the PyG object -- confirm).
        for cache_attr in ("__data_list__", "_data_list"):
            if hasattr(source, cache_attr):
                delattr(source, cache_attr)
        graph = source[0]
        nodes_data = {
            'x': graph.x,
            'y': graph.y,
            'train_mask': graph.train_mask,
            'val_mask': graph.val_mask,
            'test_mask': graph.test_mask,
        }
        static_graph = GeneralStaticGraphGenerator.create_homogeneous_static_graph(
            nodes_data, graph.edge_index
        )
        super(CoraDataset, self).__init__([static_graph])
@DatasetUniversalRegistry.register_dataset("CiteSeer".lower())
class CiteSeerDataset(InMemoryStaticGraphSet):
    """CiteSeer dataset loaded through PyG's ``Planetoid``.

    The set contains exactly one homogeneous static graph, built from the
    first item of the PyG dataset.
    """

    def __init__(self, path: str):
        """Load CiteSeer and wrap it as a one-graph set.

        :param path: base directory; ``os.path.join(path, '_pyg')`` is
            handed to ``Planetoid`` as its first argument.
        """
        source = Planetoid(os.path.join(path, '_pyg'), "CiteSeer")
        # Remove the item-list attribute under either spelling if present
        # (presumably a per-item cache on the PyG object -- confirm).
        for cache_attr in ("__data_list__", "_data_list"):
            if hasattr(source, cache_attr):
                delattr(source, cache_attr)
        graph = source[0]
        nodes_data = {
            'x': graph.x,
            'y': graph.y,
            'train_mask': graph.train_mask,
            'val_mask': graph.val_mask,
            'test_mask': graph.test_mask,
        }
        static_graph = GeneralStaticGraphGenerator.create_homogeneous_static_graph(
            nodes_data, graph.edge_index
        )
        super(CiteSeerDataset, self).__init__([static_graph])
@DatasetUniversalRegistry.register_dataset("PubMed".lower())
class PubMedDataset(InMemoryStaticGraphSet):
    """PubMed dataset loaded through PyG's ``Planetoid``.

    The set contains exactly one homogeneous static graph, built from the
    first item of the PyG dataset.
    """

    def __init__(self, path: str):
        """Load PubMed and wrap it as a one-graph set.

        :param path: base directory; ``os.path.join(path, '_pyg')`` is
            handed to ``Planetoid`` as its first argument.
        """
        source = Planetoid(os.path.join(path, '_pyg'), "PubMed")
        # Remove the item-list attribute under either spelling if present
        # (presumably a per-item cache on the PyG object -- confirm).
        for cache_attr in ("__data_list__", "_data_list"):
            if hasattr(source, cache_attr):
                delattr(source, cache_attr)
        graph = source[0]
        nodes_data = {
            'x': graph.x,
            'y': graph.y,
            'train_mask': graph.train_mask,
            'val_mask': graph.val_mask,
            'test_mask': graph.test_mask,
        }
        static_graph = GeneralStaticGraphGenerator.create_homogeneous_static_graph(
            nodes_data, graph.edge_index
        )
        super(PubMedDataset, self).__init__([static_graph])
@DatasetUniversalRegistry.register_dataset("flickr")
class FlickrDataset(InMemoryStaticGraphSet):
    """Flickr dataset loaded through PyG's ``Flickr`` class.

    The set contains exactly one homogeneous static graph carrying ``x``,
    ``y`` and the train/val/test masks of the first PyG item.
    """

    def __init__(self, path: str):
        """Load Flickr and wrap it as a one-graph set.

        :param path: base directory; ``os.path.join(path, '_pyg')`` is
            handed to ``Flickr`` as its only argument.
        """
        source = Flickr(os.path.join(path, '_pyg'))
        # Remove the item-list attribute under either spelling if present
        # (presumably a per-item cache on the PyG object -- confirm).
        for cache_attr in ("__data_list__", "_data_list"):
            if hasattr(source, cache_attr):
                delattr(source, cache_attr)
        graph = source[0]
        nodes_data = {
            'x': graph.x,
            'y': graph.y,
            'train_mask': graph.train_mask,
            'val_mask': graph.val_mask,
            'test_mask': graph.test_mask,
        }
        static_graph = GeneralStaticGraphGenerator.create_homogeneous_static_graph(
            nodes_data, graph.edge_index
        )
        super(FlickrDataset, self).__init__([static_graph])
@DatasetUniversalRegistry.register_dataset("reddit")
class RedditDataset(InMemoryStaticGraphSet):
    """Reddit dataset loaded through PyG's ``Reddit`` class.

    The set contains exactly one homogeneous static graph carrying ``x``,
    ``y`` and the train/val/test masks of the first PyG item.
    """

    def __init__(self, path: str):
        """Load Reddit and wrap it as a one-graph set.

        :param path: base directory; ``os.path.join(path, '_pyg')`` is
            handed to ``Reddit`` as its only argument.
        """
        source = Reddit(os.path.join(path, '_pyg'))
        # Remove the item-list attribute under either spelling if present
        # (presumably a per-item cache on the PyG object -- confirm).
        for cache_attr in ("__data_list__", "_data_list"):
            if hasattr(source, cache_attr):
                delattr(source, cache_attr)
        graph = source[0]
        nodes_data = {
            'x': graph.x,
            'y': graph.y,
            'train_mask': graph.train_mask,
            'val_mask': graph.val_mask,
            'test_mask': graph.test_mask,
        }
        static_graph = GeneralStaticGraphGenerator.create_homogeneous_static_graph(
            nodes_data, graph.edge_index
        )
        super(RedditDataset, self).__init__([static_graph])
@DatasetUniversalRegistry.register_dataset("amazon_computers")
class AmazonComputersDataset(InMemoryStaticGraphSet):
    """Amazon "Computers" dataset loaded through PyG's ``Amazon`` class.

    The set contains exactly one homogeneous static graph carrying only
    ``x`` and ``y``; no split masks are attached here.
    """

    def __init__(self, path: str):
        """Load Amazon/Computers and wrap it as a one-graph set.

        :param path: base directory; ``os.path.join(path, '_pyg')`` is
            handed to ``Amazon`` as its first argument.
        """
        source = Amazon(os.path.join(path, '_pyg'), "Computers")
        # Remove the item-list attribute under either spelling if present
        # (presumably a per-item cache on the PyG object -- confirm).
        for cache_attr in ("__data_list__", "_data_list"):
            if hasattr(source, cache_attr):
                delattr(source, cache_attr)
        graph = source[0]
        nodes_data = {'x': graph.x, 'y': graph.y}
        static_graph = GeneralStaticGraphGenerator.create_homogeneous_static_graph(
            nodes_data, graph.edge_index
        )
        super(AmazonComputersDataset, self).__init__([static_graph])
@DatasetUniversalRegistry.register_dataset("amazon_photo")
class AmazonPhotoDataset(InMemoryStaticGraphSet):
    """Amazon "Photo" dataset loaded through PyG's ``Amazon`` class.

    The set contains exactly one homogeneous static graph carrying only
    ``x`` and ``y``; no split masks are attached here.
    """

    def __init__(self, path: str):
        """Load Amazon/Photo and wrap it as a one-graph set.

        :param path: base directory; ``os.path.join(path, '_pyg')`` is
            handed to ``Amazon`` as its first argument.
        """
        source = Amazon(os.path.join(path, '_pyg'), "Photo")
        # Remove the item-list attribute under either spelling if present
        # (presumably a per-item cache on the PyG object -- confirm).
        for cache_attr in ("__data_list__", "_data_list"):
            if hasattr(source, cache_attr):
                delattr(source, cache_attr)
        graph = source[0]
        nodes_data = {'x': graph.x, 'y': graph.y}
        static_graph = GeneralStaticGraphGenerator.create_homogeneous_static_graph(
            nodes_data, graph.edge_index
        )
        super(AmazonPhotoDataset, self).__init__([static_graph])
@DatasetUniversalRegistry.register_dataset("coauthor_physics")
class CoauthorPhysicsDataset(InMemoryStaticGraphSet):
    """Coauthor "Physics" dataset loaded through PyG's ``Coauthor`` class.

    The set contains exactly one homogeneous static graph carrying only
    ``x`` and ``y``; no split masks are attached here.
    """

    def __init__(self, path: str):
        """Load Coauthor/Physics and wrap it as a one-graph set.

        :param path: base directory; ``os.path.join(path, '_pyg')`` is
            handed to ``Coauthor`` as its first argument.
        """
        source = Coauthor(os.path.join(path, '_pyg'), "Physics")
        # Remove the item-list attribute under either spelling if present
        # (presumably a per-item cache on the PyG object -- confirm).
        for cache_attr in ("__data_list__", "_data_list"):
            if hasattr(source, cache_attr):
                delattr(source, cache_attr)
        graph = source[0]
        nodes_data = {'x': graph.x, 'y': graph.y}
        static_graph = GeneralStaticGraphGenerator.create_homogeneous_static_graph(
            nodes_data, graph.edge_index
        )
        super(CoauthorPhysicsDataset, self).__init__([static_graph])
@DatasetUniversalRegistry.register_dataset("coauthor_cs")
class CoauthorCSDataset(InMemoryStaticGraphSet):
    """Coauthor "CS" dataset loaded through PyG's ``Coauthor`` class.

    The set contains exactly one homogeneous static graph carrying only
    ``x`` and ``y``; no split masks are attached here.
    """

    def __init__(self, path: str):
        """Load Coauthor/CS and wrap it as a one-graph set.

        :param path: base directory; ``os.path.join(path, '_pyg')`` is
            handed to ``Coauthor`` as its first argument.
        """
        source = Coauthor(os.path.join(path, '_pyg'), "CS")
        # Remove the item-list attribute under either spelling if present
        # (presumably a per-item cache on the PyG object -- confirm).
        for cache_attr in ("__data_list__", "_data_list"):
            if hasattr(source, cache_attr):
                delattr(source, cache_attr)
        graph = source[0]
        nodes_data = {'x': graph.x, 'y': graph.y}
        static_graph = GeneralStaticGraphGenerator.create_homogeneous_static_graph(
            nodes_data, graph.edge_index
        )
        super(CoauthorCSDataset, self).__init__([static_graph])
@DatasetUniversalRegistry.register_dataset("ppi")
class PPIDataset(InMemoryStaticGraphSet):
    """PPI dataset assembled from PyG's ``'train'``, ``'val'`` and ``'test'`` splits.

    Graphs of the three splits are concatenated in that order, and the
    matching contiguous index ranges are passed to the parent constructor
    alongside the graph list.
    """

    def __init__(self, path: str):
        """Load the three PPI splits and merge them into one graph set.

        :param path: base directory; ``os.path.join(path, '_pyg')`` is
            handed to ``PPI`` as its first argument for every split.
        """
        root = os.path.join(path, '_pyg')
        splits = []
        for split_name in ('train', 'val', 'test'):
            split = PPI(root, split_name)
            # Remove the item-list attribute under either spelling if
            # present (presumably a per-item cache -- confirm).
            for cache_attr in ("__data_list__", "_data_list"):
                if hasattr(split, cache_attr):
                    delattr(split, cache_attr)
            splits.append(split)
        train_split, val_split, test_split = splits

        # Positions of each split inside the concatenated graph list.
        train_end = len(train_split)
        val_end = train_end + len(val_split)
        test_end = val_end + len(test_split)
        train_index = range(train_end)
        val_index = range(train_end, val_end)
        test_index = range(val_end, test_end)

        graphs = []
        for split in (train_split, val_split, test_split):
            for data in split:
                graphs.append(
                    GeneralStaticGraphGenerator.create_homogeneous_static_graph(
                        {'x': data.x, 'y': data.y}, data.edge_index
                    )
                )
        super(PPIDataset, self).__init__(
            graphs, train_index, val_index, test_index
        )
@DatasetUniversalRegistry.register_dataset("qm9")
class QM9Dataset(InMemoryStaticGraphSet):
    """QM9 dataset loaded through PyG's ``QM9`` class.

    Every PyG item becomes one homogeneous static graph: ``x``, ``pos``
    and ``z`` go into the first data dict, ``edge_attr`` into
    ``edges_data``, and ``idx`` / ``y`` into ``graph_data``.
    """

    def __init__(self, path: str):
        """Load QM9 and convert every item into a static graph.

        :param path: base directory; ``os.path.join(path, '_pyg')`` is
            handed to ``QM9`` as its only argument.
        """
        source = QM9(os.path.join(path, '_pyg'))
        # Remove the item-list attribute under either spelling if present
        # (presumably a per-item cache on the PyG object -- confirm).
        for cache_attr in ("__data_list__", "_data_list"):
            if hasattr(source, cache_attr):
                delattr(source, cache_attr)
        graphs = []
        for data in source:
            graphs.append(
                GeneralStaticGraphGenerator.create_homogeneous_static_graph(
                    {'x': data.x, 'pos': data.pos, 'z': data.z},
                    data.edge_index,
                    edges_data={'edge_attr': data.edge_attr},
                    graph_data={'idx': data.idx, 'y': data.y},
                )
            )
        super(QM9Dataset, self).__init__(graphs)
@DatasetUniversalRegistry.register_dataset("mutag")
class MUTAGDataset(InMemoryStaticGraphSet):
    """MUTAG dataset loaded through PyG's ``TUDataset``.

    Every PyG item becomes one homogeneous static graph with ``x`` in the
    first data dict, ``edge_attr`` in ``edges_data`` and ``y`` in
    ``graph_data``.
    """

    def __init__(self, path: str):
        """Load MUTAG and convert every item into a static graph.

        :param path: base directory; ``os.path.join(path, '_pyg')`` is
            handed to ``TUDataset`` as its first argument.
        """
        source = TUDataset(os.path.join(path, '_pyg'), "MUTAG")
        # Remove the item-list attribute under either spelling if present
        # (presumably a per-item cache on the PyG object -- confirm).
        for cache_attr in ("__data_list__", "_data_list"):
            if hasattr(source, cache_attr):
                delattr(source, cache_attr)
        graphs = []
        for item in source:
            graphs.append(
                GeneralStaticGraphGenerator.create_homogeneous_static_graph(
                    {'x': item.x},
                    item.edge_index,
                    edges_data={'edge_attr': item.edge_attr},
                    graph_data={'y': item.y},
                )
            )
        super(MUTAGDataset, self).__init__(graphs)
@DatasetUniversalRegistry.register_dataset("enzymes")
class ENZYMESDataset(InMemoryStaticGraphSet):
    """ENZYMES dataset loaded through PyG's ``TUDataset``.

    Every PyG item becomes one homogeneous static graph with ``x`` in the
    first data dict and ``y`` in ``graph_data``.
    """

    def __init__(self, path: str):
        """Load ENZYMES and convert every item into a static graph.

        :param path: base directory; ``os.path.join(path, '_pyg')`` is
            handed to ``TUDataset`` as its first argument.
        """
        source = TUDataset(os.path.join(path, '_pyg'), "ENZYMES")
        # Remove the item-list attribute under either spelling if present
        # (presumably a per-item cache on the PyG object -- confirm).
        for cache_attr in ("__data_list__", "_data_list"):
            if hasattr(source, cache_attr):
                delattr(source, cache_attr)
        graphs = []
        for item in source:
            graphs.append(
                GeneralStaticGraphGenerator.create_homogeneous_static_graph(
                    {'x': item.x},
                    item.edge_index,
                    graph_data={'y': item.y},
                )
            )
        super(ENZYMESDataset, self).__init__(graphs)
@DatasetUniversalRegistry.register_dataset("imdb-b")
@DatasetUniversalRegistry.register_dataset("imdb-binary")
@DatasetUniversalRegistry.register_dataset("IMDBb".upper())
@DatasetUniversalRegistry.register_dataset("IMDBBinary".upper())
@DatasetUniversalRegistry.register_dataset("IMDBBinary")
class IMDBBinaryDataset(InMemoryStaticGraphSet):
    """IMDB-BINARY dataset loaded through PyG's ``TUDataset``.

    The class is registered under several alias names. Every PyG item
    becomes one homogeneous static graph with an empty first data dict
    and ``y`` in ``graph_data``.
    """

    def __init__(self, path: str):
        """Load IMDB-BINARY and convert every item into a static graph.

        :param path: base directory; ``os.path.join(path, '_pyg')`` is
            handed to ``TUDataset`` as its first argument.
        """
        source = TUDataset(os.path.join(path, '_pyg'), "IMDB-BINARY")
        # Remove the item-list attribute under either spelling if present
        # (presumably a per-item cache on the PyG object -- confirm).
        for cache_attr in ("__data_list__", "_data_list"):
            if hasattr(source, cache_attr):
                delattr(source, cache_attr)
        graphs = []
        for item in source:
            graphs.append(
                GeneralStaticGraphGenerator.create_homogeneous_static_graph(
                    {}, item.edge_index, graph_data={'y': item.y}
                )
            )
        super(IMDBBinaryDataset, self).__init__(graphs)
@DatasetUniversalRegistry.register_dataset("imdb-m")
@DatasetUniversalRegistry.register_dataset("imdb-multi")
@DatasetUniversalRegistry.register_dataset("IMDBm".upper())
@DatasetUniversalRegistry.register_dataset("IMDBMulti".upper())
@DatasetUniversalRegistry.register_dataset("IMDBMulti")
class IMDBMultiDataset(InMemoryStaticGraphSet):
    """IMDB-MULTI dataset loaded through PyG's ``TUDataset``.

    The class is registered under several alias names. Every PyG item
    becomes one homogeneous static graph with an empty first data dict
    and ``y`` in ``graph_data``.
    """

    def __init__(self, path: str):
        """Load IMDB-MULTI and convert every item into a static graph.

        :param path: base directory; ``os.path.join(path, '_pyg')`` is
            handed to ``TUDataset`` as its first argument.
        """
        source = TUDataset(os.path.join(path, '_pyg'), "IMDB-MULTI")
        # Remove the item-list attribute under either spelling if present
        # (presumably a per-item cache on the PyG object -- confirm).
        for cache_attr in ("__data_list__", "_data_list"):
            if hasattr(source, cache_attr):
                delattr(source, cache_attr)
        graphs = []
        for item in source:
            graphs.append(
                GeneralStaticGraphGenerator.create_homogeneous_static_graph(
                    {}, item.edge_index, graph_data={'y': item.y}
                )
            )
        super(IMDBMultiDataset, self).__init__(graphs)
@DatasetUniversalRegistry.register_dataset("reddit-b")
@DatasetUniversalRegistry.register_dataset("reddit-binary")
@DatasetUniversalRegistry.register_dataset("RedditB".upper())
@DatasetUniversalRegistry.register_dataset("RedditBinary".upper())
class RedditBinaryDataset(InMemoryStaticGraphSet):
    """REDDIT-BINARY dataset loaded through PyG's ``TUDataset``.

    The class is registered under several alias names. Every PyG item
    becomes one homogeneous static graph with an empty first data dict
    and ``y`` in ``graph_data``.
    """

    def __init__(self, path: str):
        """Load REDDIT-BINARY and convert every item into a static graph.

        :param path: base directory; ``os.path.join(path, '_pyg')`` is
            handed to ``TUDataset`` as its first argument.
        """
        source = TUDataset(os.path.join(path, '_pyg'), "REDDIT-BINARY")
        # Remove the item-list attribute under either spelling if present
        # (presumably a per-item cache on the PyG object -- confirm).
        for cache_attr in ("__data_list__", "_data_list"):
            if hasattr(source, cache_attr):
                delattr(source, cache_attr)
        graphs = []
        for item in source:
            graphs.append(
                GeneralStaticGraphGenerator.create_homogeneous_static_graph(
                    {}, item.edge_index, graph_data={'y': item.y}
                )
            )
        super(RedditBinaryDataset, self).__init__(graphs)
@DatasetUniversalRegistry.register_dataset("reddit-multi-5k")
@DatasetUniversalRegistry.register_dataset("RedditMulti5K".upper())
class REDDITMulti5KDataset(InMemoryStaticGraphSet):
    """REDDIT-MULTI-5K dataset loaded through PyG's ``TUDataset``.

    The class is registered under two alias names. Every PyG item becomes
    one homogeneous static graph with an empty first data dict and ``y``
    in ``graph_data``.
    """

    def __init__(self, path: str):
        """Load REDDIT-MULTI-5K and convert every item into a static graph.

        :param path: base directory; ``os.path.join(path, '_pyg')`` is
            handed to ``TUDataset`` as its first argument.
        """
        source = TUDataset(os.path.join(path, '_pyg'), "REDDIT-MULTI-5K")
        # Remove the item-list attribute under either spelling if present
        # (presumably a per-item cache on the PyG object -- confirm).
        for cache_attr in ("__data_list__", "_data_list"):
            if hasattr(source, cache_attr):
                delattr(source, cache_attr)
        graphs = []
        for item in source:
            graphs.append(
                GeneralStaticGraphGenerator.create_homogeneous_static_graph(
                    {}, item.edge_index, graph_data={'y': item.y}
                )
            )
        super(REDDITMulti5KDataset, self).__init__(graphs)
@DatasetUniversalRegistry.register_dataset("reddit-multi-12k")
@DatasetUniversalRegistry.register_dataset("RedditMulti12K".upper())
class REDDITMulti12KDataset(InMemoryStaticGraphSet):
    """REDDIT-MULTI-12K dataset loaded through PyG's ``TUDataset``.

    The class is registered under two alias names. Every PyG item becomes
    one homogeneous static graph with an empty first data dict and ``y``
    in ``graph_data``.
    """

    def __init__(self, path: str):
        """Load REDDIT-MULTI-12K and convert every item into a static graph.

        :param path: base directory; ``os.path.join(path, '_pyg')`` is
            handed to ``TUDataset`` as its first argument.
        """
        source = TUDataset(os.path.join(path, '_pyg'), "REDDIT-MULTI-12K")
        # Remove the item-list attribute under either spelling if present
        # (presumably a per-item cache on the PyG object -- confirm).
        for cache_attr in ("__data_list__", "_data_list"):
            if hasattr(source, cache_attr):
                delattr(source, cache_attr)
        graphs = []
        for item in source:
            graphs.append(
                GeneralStaticGraphGenerator.create_homogeneous_static_graph(
                    {}, item.edge_index, graph_data={'y': item.y}
                )
            )
        super(REDDITMulti12KDataset, self).__init__(graphs)
@DatasetUniversalRegistry.register_dataset("collab")
class COLLABDataset(InMemoryStaticGraphSet):
    """COLLAB dataset loaded through PyG's ``TUDataset``.

    Every PyG item becomes one homogeneous static graph with an empty
    first data dict and ``y`` in ``graph_data``.
    """

    def __init__(self, path: str):
        """Load COLLAB and convert every item into a static graph.

        :param path: base directory; ``os.path.join(path, '_pyg')`` is
            handed to ``TUDataset`` as its first argument.
        """
        source = TUDataset(os.path.join(path, '_pyg'), "COLLAB")
        # Remove the item-list attribute under either spelling if present
        # (presumably a per-item cache on the PyG object -- confirm).
        for cache_attr in ("__data_list__", "_data_list"):
            if hasattr(source, cache_attr):
                delattr(source, cache_attr)
        graphs = []
        for item in source:
            graphs.append(
                GeneralStaticGraphGenerator.create_homogeneous_static_graph(
                    {}, item.edge_index, graph_data={'y': item.y}
                )
            )
        super(COLLABDataset, self).__init__(graphs)
@DatasetUniversalRegistry.register_dataset("proteins")
class ProteinsDataset(InMemoryStaticGraphSet):
    """PROTEINS dataset loaded through PyG's ``TUDataset``.

    Every PyG item becomes one homogeneous static graph with ``x`` in the
    first data dict and ``y`` in ``graph_data``.
    """

    def __init__(self, path: str):
        """Load PROTEINS and convert every item into a static graph.

        :param path: base directory; ``os.path.join(path, '_pyg')`` is
            handed to ``TUDataset`` as its first argument.
        """
        source = TUDataset(os.path.join(path, '_pyg'), "PROTEINS")
        # Remove the item-list attribute under either spelling if present
        # (presumably a per-item cache on the PyG object -- confirm).
        for cache_attr in ("__data_list__", "_data_list"):
            if hasattr(source, cache_attr):
                delattr(source, cache_attr)
        graphs = []
        for item in source:
            graphs.append(
                GeneralStaticGraphGenerator.create_homogeneous_static_graph(
                    {'x': item.x},
                    item.edge_index,
                    graph_data={'y': item.y},
                )
            )
        super(ProteinsDataset, self).__init__(graphs)
@DatasetUniversalRegistry.register_dataset("ptc-mr")
class PTCMRDataset(InMemoryStaticGraphSet):
    """PTC_MR dataset loaded through PyG's ``TUDataset``.

    Every PyG item becomes one homogeneous static graph with ``x`` in the
    first data dict, ``edge_attr`` in ``edges_data`` and ``y`` in
    ``graph_data``.
    """

    def __init__(self, path: str):
        """Load PTC_MR and convert every item into a static graph.

        :param path: base directory; ``os.path.join(path, '_pyg')`` is
            handed to ``TUDataset`` as its first argument.
        """
        source = TUDataset(os.path.join(path, '_pyg'), "PTC_MR")
        # Remove the item-list attribute under either spelling if present
        # (presumably a per-item cache on the PyG object -- confirm).
        for cache_attr in ("__data_list__", "_data_list"):
            if hasattr(source, cache_attr):
                delattr(source, cache_attr)
        graphs = []
        for item in source:
            graphs.append(
                GeneralStaticGraphGenerator.create_homogeneous_static_graph(
                    {'x': item.x},
                    item.edge_index,
                    edges_data={'edge_attr': item.edge_attr},
                    graph_data={'y': item.y},
                )
            )
        super(PTCMRDataset, self).__init__(graphs)
@DatasetUniversalRegistry.register_dataset("nci1")
class NCI1Dataset(InMemoryStaticGraphSet):
    """NCI1 dataset loaded through PyG's ``TUDataset``.

    Every PyG item becomes one homogeneous static graph with ``x`` in the
    first data dict and ``y`` in ``graph_data``.
    """

    def __init__(self, path: str):
        """Load NCI1 and convert every item into a static graph.

        :param path: base directory; ``os.path.join(path, '_pyg')`` is
            handed to ``TUDataset`` as its first argument.
        """
        source = TUDataset(os.path.join(path, '_pyg'), "NCI1")
        # Remove the item-list attribute under either spelling if present
        # (presumably a per-item cache on the PyG object -- confirm).
        for cache_attr in ("__data_list__", "_data_list"):
            if hasattr(source, cache_attr):
                delattr(source, cache_attr)
        graphs = []
        for item in source:
            graphs.append(
                GeneralStaticGraphGenerator.create_homogeneous_static_graph(
                    {'x': item.x},
                    item.edge_index,
                    graph_data={'y': item.y},
                )
            )
        super(NCI1Dataset, self).__init__(graphs)
@DatasetUniversalRegistry.register_dataset("nci109")
class NCI109Dataset(InMemoryStaticGraphSet):
    """NCI109 dataset loaded through PyG's ``TUDataset``.

    Every PyG item becomes one homogeneous static graph with ``x`` in the
    first data dict and ``y`` in ``graph_data``.
    """

    def __init__(self, path: str):
        """Load NCI109 and convert every item into a static graph.

        :param path: base directory; ``os.path.join(path, '_pyg')`` is
            handed to ``TUDataset`` as its first argument.
        """
        source = TUDataset(os.path.join(path, '_pyg'), "NCI109")
        # Remove the item-list attribute under either spelling if present
        # (presumably a per-item cache on the PyG object -- confirm).
        for cache_attr in ("__data_list__", "_data_list"):
            if hasattr(source, cache_attr):
                delattr(source, cache_attr)
        graphs = []
        for item in source:
            graphs.append(
                GeneralStaticGraphGenerator.create_homogeneous_static_graph(
                    {'x': item.x},
                    item.edge_index,
                    graph_data={'y': item.y},
                )
            )
        super(NCI109Dataset, self).__init__(graphs)
@DatasetUniversalRegistry.register_dataset("ModelNet10Training")
class ModelNet10TrainingDataset(InMemoryStaticGraphSet):
    """ModelNet ``'10'`` dataset loaded through PyG's ``ModelNet`` class.

    ``True`` is passed as the third positional argument to ``ModelNet``
    (the ``train`` flag in PyG's signature -- confirm), with ``FaceToEdge``
    as the pre-transform. Every PyG item becomes one homogeneous static
    graph with ``pos`` in the first data dict and ``y`` in ``graph_data``.
    """

    def __init__(self, path: str):
        """Load the dataset and convert every item into a static graph.

        :param path: base directory; ``os.path.join(path, '_pyg')`` is
            handed to ``ModelNet`` as its first argument.
        """
        root = os.path.join(path, '_pyg')
        face_to_edge = torch_geometric.transforms.FaceToEdge()
        source = ModelNet(root, '10', True, pre_transform=face_to_edge)
        # Remove the item-list attribute under either spelling if present
        # (presumably a per-item cache on the PyG object -- confirm).
        for cache_attr in ("__data_list__", "_data_list"):
            if hasattr(source, cache_attr):
                delattr(source, cache_attr)
        graphs = []
        for item in source:
            graphs.append(
                GeneralStaticGraphGenerator.create_homogeneous_static_graph(
                    {'pos': item.pos},
                    item.edge_index,
                    graph_data={'y': item.y},
                )
            )
        super(ModelNet10TrainingDataset, self).__init__(graphs)
@DatasetUniversalRegistry.register_dataset("ModelNet10Test")
class ModelNet10TestDataset(InMemoryStaticGraphSet):
    """ModelNet ``'10'`` dataset loaded through PyG's ``ModelNet`` class.

    ``False`` is passed as the third positional argument to ``ModelNet``
    (the ``train`` flag in PyG's signature -- confirm), with ``FaceToEdge``
    as the pre-transform. Every PyG item becomes one homogeneous static
    graph with ``pos`` in the first data dict and ``y`` in ``graph_data``.
    """

    def __init__(self, path: str):
        """Load the dataset and convert every item into a static graph.

        :param path: base directory; ``os.path.join(path, '_pyg')`` is
            handed to ``ModelNet`` as its first argument.
        """
        root = os.path.join(path, '_pyg')
        face_to_edge = torch_geometric.transforms.FaceToEdge()
        source = ModelNet(root, '10', False, pre_transform=face_to_edge)
        # Remove the item-list attribute under either spelling if present
        # (presumably a per-item cache on the PyG object -- confirm).
        for cache_attr in ("__data_list__", "_data_list"):
            if hasattr(source, cache_attr):
                delattr(source, cache_attr)
        graphs = []
        for item in source:
            graphs.append(
                GeneralStaticGraphGenerator.create_homogeneous_static_graph(
                    {'pos': item.pos},
                    item.edge_index,
                    graph_data={'y': item.y},
                )
            )
        super(ModelNet10TestDataset, self).__init__(graphs)
@DatasetUniversalRegistry.register_dataset("ModelNet40Training")
class ModelNet40TrainingDataset(InMemoryStaticGraphSet):
    """ModelNet ``'40'`` dataset loaded through PyG's ``ModelNet`` class.

    ``True`` is passed as the third positional argument to ``ModelNet``
    (the ``train`` flag in PyG's signature -- confirm), with ``FaceToEdge``
    as the pre-transform. Every PyG item becomes one homogeneous static
    graph with ``pos`` in the first data dict and ``y`` in ``graph_data``.
    """

    def __init__(self, path: str):
        """Load the dataset and convert every item into a static graph.

        :param path: base directory; ``os.path.join(path, '_pyg')`` is
            handed to ``ModelNet`` as its first argument.
        """
        root = os.path.join(path, '_pyg')
        face_to_edge = torch_geometric.transforms.FaceToEdge()
        source = ModelNet(root, '40', True, pre_transform=face_to_edge)
        # Remove the item-list attribute under either spelling if present
        # (presumably a per-item cache on the PyG object -- confirm).
        for cache_attr in ("__data_list__", "_data_list"):
            if hasattr(source, cache_attr):
                delattr(source, cache_attr)
        graphs = []
        for item in source:
            graphs.append(
                GeneralStaticGraphGenerator.create_homogeneous_static_graph(
                    {'pos': item.pos},
                    item.edge_index,
                    graph_data={'y': item.y},
                )
            )
        super(ModelNet40TrainingDataset, self).__init__(graphs)
@DatasetUniversalRegistry.register_dataset("ModelNet40Test")
class ModelNet40TestDataset(InMemoryStaticGraphSet):
    """ModelNet ``'40'`` dataset loaded through PyG's ``ModelNet`` class.

    ``False`` is passed as the third positional argument to ``ModelNet``
    (the ``train`` flag in PyG's signature -- confirm), with ``FaceToEdge``
    as the pre-transform. Every PyG item becomes one homogeneous static
    graph with ``pos`` in the first data dict and ``y`` in ``graph_data``.
    """

    def __init__(self, path: str):
        """Load the dataset and convert every item into a static graph.

        :param path: base directory; ``os.path.join(path, '_pyg')`` is
            handed to ``ModelNet`` as its first argument.
        """
        root = os.path.join(path, '_pyg')
        face_to_edge = torch_geometric.transforms.FaceToEdge()
        source = ModelNet(root, '40', False, pre_transform=face_to_edge)
        # Remove the item-list attribute under either spelling if present
        # (presumably a per-item cache on the PyG object -- confirm).
        for cache_attr in ("__data_list__", "_data_list"):
            if hasattr(source, cache_attr):
                delattr(source, cache_attr)
        graphs = []
        for item in source:
            graphs.append(
                GeneralStaticGraphGenerator.create_homogeneous_static_graph(
                    {'pos': item.pos},
                    item.edge_index,
                    graph_data={'y': item.y},
                )
            )
        super(ModelNet40TestDataset, self).__init__(graphs)