import itertools
import os
import scipy.io
import torch
import typing as _typing
from autogl.data import Data, download_url, InMemoryStaticGraphSet
from autogl.data.graph import GeneralStaticGraphGenerator
from ._dataset_registry import DatasetUniversalRegistry
from ._data_source import OnlineDataSource
from .. import backend as _backend
class _MATLABMatrix(OnlineDataSource):
    """Online data source for a graph stored in a MATLAB ``.mat`` file.

    ``<url><name>.mat`` is downloaded into the raw directory. ``_process`` then
    reads its ``network`` (adjacency) and ``group`` (node labels) entries and
    caches a single :class:`Data` object as ``data.pt`` in the processed
    directory.

    :param path: root directory handed to :class:`OnlineDataSource`.
    :param name: base file name, without the ``.mat`` extension.
    :param url: URL prefix. ``name + ".mat"`` is appended verbatim, so the
        prefix should end with ``/``.
    """

    def __init__(self, path: str, name: str, url: str):
        # Set before the base initializer runs: it presumably triggers
        # ``_fetch``/``_process``, which read these attributes.
        self.__name: str = name
        self.__url: str = url
        super(_MATLABMatrix, self).__init__(path)
        self.__data = self._load_processed(list(self._processed_file_paths)[0])

    @staticmethod
    def _load_processed(file_path: str):
        """Load the pickled ``Data`` object written by ``_process``."""
        import inspect

        # PyTorch >= 2.6 defaults to ``weights_only=True``, which refuses to
        # unpickle arbitrary objects such as ``Data``. The file is produced
        # locally by ``_process``, so a full unpickle is acceptable here. Do
        # not point this at untrusted caches. Older PyTorch versions do not
        # accept the keyword at all.
        if "weights_only" in inspect.signature(torch.load).parameters:
            return torch.load(file_path, weights_only=False)
        return torch.load(file_path)

    @property
    def _raw_filenames(self) -> _typing.Iterable[str]:
        return [f"{self.__name}.mat"]

    @property
    def _processed_filenames(self) -> _typing.Iterable[str]:
        return ["data.pt"]

    def _fetch(self):
        """Download every raw file from ``url + filename`` into the raw directory."""
        for name in self._raw_filenames:
            download_url(self.__url + name, self._raw_directory)

    def _process(self):
        """Convert the raw ``.mat`` file into a ``Data`` object and save it."""
        path = os.path.join(self._raw_directory, f"{self.__name}.mat")
        mat = scipy.io.loadmat(path)
        adj_matrix, group = mat["network"], mat["group"]

        # ``todense()`` yields a 2-D matrix; labels are stored as float.
        y = torch.from_numpy(group.todense()).to(torch.float)

        row_ind, col_ind = adj_matrix.nonzero()
        # scipy commonly returns int32 indices; graph libraries expect int64.
        edge_index = torch.stack(
            [
                torch.as_tensor(row_ind, dtype=torch.long),
                torch.as_tensor(col_ind, dtype=torch.long),
            ],
            dim=0,
        )
        # Fancy-indexing a scipy sparse matrix returns a (1, E) numpy.matrix.
        # Flatten it so there is exactly one weight per edge along dim 0.
        edge_attr = torch.tensor(adj_matrix[row_ind, col_ind]).reshape(-1)

        data = Data(edge_index=edge_index, edge_attr=edge_attr, x=None, y=y)
        torch.save(data, list(self._processed_file_paths)[0])

    def __len__(self) -> int:
        return 1

    def __getitem__(self, index: int):
        if index != 0:
            raise IndexError(
                f"_MATLABMatrix holds a single graph; index {index} is out of range"
            )
        return self.__data
@DatasetUniversalRegistry.register_dataset("BlogCatalog".lower())
class BlogCatalogDataset(InMemoryStaticGraphSet):
    """BlogCatalog graph, registered under the name ``"blogcatalog"``.

    ``blogcatalog.mat`` is fetched from
    ``http://leitang.net/code/social-dimension/data/`` and wrapped as one
    homogeneous static graph. It carries a node-label tensor and a per-edge
    ``'edge_attr'`` tensor.

    :param path: root directory used for the raw download and processed cache.
    :raises NotImplementedError: if the active backend is neither DGL nor PyG.
    """

    def __init__(self, path: str):
        filename: str = "BlogCatalog".lower()
        url: str = "http://leitang.net/code/social-dimension/data/"
        data = _MATLABMatrix(path, filename, url)[0]

        # The two backends differ only in the conventional node-label key.
        if _backend.DependentBackend.is_dgl():
            label_key = 'label'
        elif _backend.DependentBackend.is_pyg():
            label_key = 'y'
        else:
            # Without this, the base class would never be initialized.
            raise NotImplementedError(
                "BlogCatalogDataset supports only the DGL and PyG backends"
            )

        super(BlogCatalogDataset, self).__init__(
            [
                GeneralStaticGraphGenerator.create_homogeneous_static_graph(
                    {label_key: data.y}, data.edge_index,
                    {'edge_attr': data.edge_attr}
                )
            ]
        )
@DatasetUniversalRegistry.register_dataset("WikiPEDIA".lower())
class WIKIPEDIADataset(InMemoryStaticGraphSet):
    """Wikipedia POS graph, registered under the name ``"wikipedia"``.

    ``POS.mat`` is fetched from ``http://snap.stanford.edu/node2vec/`` and
    wrapped as one homogeneous static graph. It carries a node-label tensor
    and a per-edge ``'attr'`` tensor.

    :param path: root directory used for the raw download and processed cache.
    :raises NotImplementedError: if the active backend is neither DGL nor PyG.
    """

    def __init__(self, path: str):
        filename: str = "POS"
        url = "http://snap.stanford.edu/node2vec/"
        data = _MATLABMatrix(path, filename, url)[0]

        # The two backends differ only in the conventional node-label key.
        if _backend.DependentBackend.is_dgl():
            label_key = 'label'
        elif _backend.DependentBackend.is_pyg():
            label_key = 'y'
        else:
            # Without this, the base class would never be initialized.
            raise NotImplementedError(
                "WIKIPEDIADataset supports only the DGL and PyG backends"
            )

        super(WIKIPEDIADataset, self).__init__(
            [
                GeneralStaticGraphGenerator.create_homogeneous_static_graph(
                    {label_key: data.y}, data.edge_index,
                    # NOTE(review): BlogCatalogDataset uses 'edge_attr' here.
                    # 'attr' is kept for compatibility with existing consumers.
                    {'attr': data.edge_attr}
                )
            ]
        )