Source code for autogl.datasets._matlab_matrix

import itertools
import os

import scipy.io
import torch
import typing as _typing

from autogl.data import Data, download_url, InMemoryStaticGraphSet
from autogl.data.graph import GeneralStaticGraphGenerator
from ._dataset_registry import DatasetUniversalRegistry
from ._data_source import OnlineDataSource
from .. import backend as _backend


class _MATLABMatrix(OnlineDataSource):
    """Single-item data source built from a MATLAB ``.mat`` file found online.

    The raw file ``<name>.mat`` is downloaded from ``url + "<name>.mat"``,
    converted into one ``Data`` object and stored as ``data.pt``. The stored
    object is loaded at construction time and served as item ``0``.
    """

    def __init__(self, path: str, name: str, url: str):
        """
        :param path: forwarded unchanged to the ``OnlineDataSource`` initializer
        :param name: stem of the ``.mat`` file (``<name>.mat``)
        :param url: prefix the raw file name is appended to when downloading
        """
        # Both attributes are assigned before the parent initializer runs;
        # presumably it invokes _fetch/_process, which read them - TODO confirm.
        self.__name: str = name
        self.__url: str = url
        super().__init__(path)
        # NOTE(review): torch.load unpickles a full ``Data`` object here; the
        # file is the one written by _process. Recent torch releases may
        # require ``weights_only=False`` for this - confirm against the
        # supported torch versions.
        cached_path = list(self._processed_file_paths)[0]
        self.__data = torch.load(cached_path)

    @property
    def _raw_filenames(self) -> _typing.Iterable[str]:
        """Names of the raw files to download: just ``<name>.mat``."""
        return ["{}.mat".format(self.__name)]

    @property
    def _processed_filenames(self) -> _typing.Iterable[str]:
        """Names of the processed files: just ``data.pt``."""
        return ["data.pt"]

    def _fetch(self):
        """Download every raw file into the raw directory."""
        for filename in self._raw_filenames:
            remote_location = self.__url + filename
            download_url(remote_location, self._raw_directory)

    def _process(self):
        """Convert the raw ``.mat`` file into a ``Data`` object and save it.

        Reads the ``"network"`` and ``"group"`` entries of the file. The
        non-zero positions of ``network`` become ``edge_index``, the values at
        those positions become ``edge_attr``, and the densified ``group``
        becomes the float tensor ``y``. No node features are stored
        (``x=None``).
        """
        mat_path = os.path.join(self._raw_directory, f"{self.__name}.mat")
        contents = scipy.io.loadmat(mat_path)
        adjacency = contents["network"]
        labels = contents["group"]

        y = torch.from_numpy(labels.todense()).to(torch.float)

        sources, targets = adjacency.nonzero()
        # NOTE(review): the index tensors keep whatever integer dtype
        # ``nonzero()`` returns - confirm that downstream code accepts it.
        edge_index = torch.stack(
            (torch.tensor(sources), torch.tensor(targets)), dim=0
        )
        # NOTE(review): if ``network`` is a scipy sparse matrix, this fancy
        # indexing presumably yields a 2-D (1, nnz) result rather than a flat
        # vector - confirm the intended edge_attr shape.
        edge_attr = torch.tensor(adjacency[sources, targets])

        graph = Data(edge_index=edge_index, edge_attr=edge_attr, x=None, y=y)
        torch.save(graph, list(self._processed_file_paths)[0])

    def __len__(self) -> int:
        """This source always holds exactly one item."""
        return 1

    def __getitem__(self, index: int):
        """Return the loaded ``Data`` object for index ``0``.

        :raises IndexError: for any index other than ``0``
        """
        if index == 0:
            return self.__data
        raise IndexError


@DatasetUniversalRegistry.register_dataset("BlogCatalog".lower())
class BlogCatalogDataset(InMemoryStaticGraphSet):
    """BlogCatalog graph wrapped as a one-element in-memory static graph set.

    The graph is read through ``_MATLABMatrix`` from ``blogcatalog.mat`` under
    ``http://leitang.net/code/social-dimension/data/``. Node labels are stored
    under ``'label'`` on the DGL backend and under ``'y'`` on the PyG backend;
    edge data is stored under ``'edge_attr'``.
    """

    def __init__(self, path: str):
        """
        :param path: passed to ``_MATLABMatrix`` as its ``path`` argument
        """
        dataset_name: str = "BlogCatalog".lower()
        base_url: str = "http://leitang.net/code/social-dimension/data/"
        source = _MATLABMatrix(path, dataset_name, base_url)
        data = source[0]

        # Only the key used for the node labels depends on the backend.
        if _backend.DependentBackend.is_dgl():
            node_label_key = 'label'
        elif _backend.DependentBackend.is_pyg():
            node_label_key = 'y'
        else:
            # Neither backend reported: the parent initializer is not called.
            return

        graph = GeneralStaticGraphGenerator.create_homogeneous_static_graph(
            {node_label_key: data.y},
            data.edge_index,
            {'edge_attr': data.edge_attr},
        )
        super(BlogCatalogDataset, self).__init__([graph])
@DatasetUniversalRegistry.register_dataset("WikiPEDIA".lower())
class WIKIPEDIADataset(InMemoryStaticGraphSet):
    """Wikipedia graph wrapped as a one-element in-memory static graph set.

    The graph is read through ``_MATLABMatrix`` from ``POS.mat`` under
    ``http://snap.stanford.edu/node2vec/``. Node labels are stored under
    ``'label'`` on the DGL backend and under ``'y'`` on the PyG backend;
    edge data is stored under ``'attr'``.
    """

    def __init__(self, path: str):
        """
        :param path: passed to ``_MATLABMatrix`` as its ``path`` argument
        """
        dataset_name: str = "POS"
        base_url = "http://snap.stanford.edu/node2vec/"
        source = _MATLABMatrix(path, dataset_name, base_url)
        data = source[0]

        # Only the key used for the node labels depends on the backend.
        if _backend.DependentBackend.is_dgl():
            node_label_key = 'label'
        elif _backend.DependentBackend.is_pyg():
            node_label_key = 'y'
        else:
            # Neither backend reported: the parent initializer is not called.
            return

        graph = GeneralStaticGraphGenerator.create_homogeneous_static_graph(
            {node_label_key: data.y},
            data.edge_index,
            {'attr': data.edge_attr},
        )
        super(WIKIPEDIADataset, self).__init__([graph])