# Source code for autogl.data.data

import re

import torch


def index_to_mask(index, size):
    """Convert an index tensor into a boolean mask.

    Args:
        index: Tensor used to index a 1-D tensor of length ``size``; every
            position it selects is set to ``True``.
        size: Length of the returned mask.

    Returns:
        A ``torch.bool`` tensor of length ``size`` on the same device as
        ``index``, ``True`` exactly at the selected positions.
    """
    out = torch.zeros(size, dtype=torch.bool, device=index.device)
    out[index] = True
    return out


class Data(object):
    r"""A plain python object modeling a single graph with various
    (optional) attributes:

    Args:
        x (Tensor, optional): Node feature matrix with shape
            :obj:`[num_nodes, num_node_features]`. (default: :obj:`None`)
        edge_index (LongTensor, optional): Graph connectivity in COO format
            with shape :obj:`[2, num_edges]`. (default: :obj:`None`)
        edge_attr (Tensor, optional): Edge feature matrix with shape
            :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`)
        y (Tensor, optional): Graph or node targets with arbitrary shape.
            (default: :obj:`None`)
        pos (Tensor, optional): Node position matrix with shape
            :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`)

    The data object is not restricted to these attributes and can be extended
    by any other additional data.
    """

    def __init__(self, x=None, edge_index=None, edge_attr=None, y=None, pos=None):
        self.x = x
        self.edge_index = edge_index
        self.edge_attr = edge_attr
        self.y = y
        self.pos = pos

    @staticmethod
    def from_dict(dictionary):
        r"""Creates a data object from a python dictionary."""
        data = Data()
        for key, item in dictionary.items():
            data[key] = item
        return data

    def __getitem__(self, key):
        r"""Gets the data of the attribute :obj:`key`."""
        return getattr(self, key)

    def __setitem__(self, key, value):
        r"""Sets the attribute :obj:`key` to :obj:`value`."""
        setattr(self, key, value)

    @property
    def keys(self):
        r"""Returns all names of graph attributes.

        Attributes whose value is :obj:`None`, and names that start or end
        with a double underscore, are excluded.
        """
        keys = [key for key in self.__dict__.keys() if self[key] is not None]
        keys = [key for key in keys if key[:2] != "__" and key[-2:] != "__"]
        return keys

    def __len__(self):
        r"""Returns the number of all present attributes."""
        return len(self.keys)

    def __contains__(self, key):
        r"""Returns :obj:`True`, if the attribute :obj:`key` is present in the
        data."""
        return key in self.keys

    def __iter__(self):
        r"""Iterates over all present attributes in the data, yielding their
        attribute names and content (sorted by name)."""
        for key in sorted(self.keys):
            yield key, self[key]

    def __call__(self, *keys):
        r"""Iterates over all attributes :obj:`*keys` in the data, yielding
        their attribute names and content.
        If :obj:`*keys` is not given this method will iterate over all
        present attributes. Requested keys that are missing or :obj:`None`
        are skipped."""
        for key in sorted(self.keys) if not keys else keys:
            # getattr with a default: a requested-but-absent key is skipped
            # instead of raising AttributeError.
            item = getattr(self, key, None)
            if item is not None:
                yield key, item

    def cat_dim(self, key, value):
        r"""Returns the dimension in which the attribute :obj:`key` with
        content :obj:`value` gets concatenated when creating batches.

        .. note::

            This method is for internal use only, and should only be
            overridden if the batch concatenation process is corrupted for a
            specific data attribute.
        """
        # `*index*` and `*face*` should be concatenated in the last dimension,
        # everything else in the first dimension.
        return -1 if bool(re.search("(index|face)", key)) else 0

    # own methods for processing

    def get_label_number(self):
        r"""Get the number of labels in this dataset as dict.

        Returns:
            dict: maps every distinct value of :obj:`y` (as a python scalar,
            in ascending order) to the number of times it occurs.
        """
        # One pass over y instead of one comparison per distinct label.
        labels, counts = torch.unique(self.y.detach(), return_counts=True)
        return dict(zip(labels.cpu().tolist(), counts.cpu().tolist()))

    # ------------------------------------------------------------------
    # helpers shared by the random split methods
    # ------------------------------------------------------------------

    @staticmethod
    def _save_and_seed_rng(seed):
        """Snapshot the global torch RNG state(s), then optionally reseed.

        The CUDA RNG is only queried when CUDA is available, so splitting
        also works on CPU-only machines. Returns an opaque state for
        :meth:`_restore_rng`.
        """
        cuda_ok = torch.cuda.is_available()
        cpu_state = torch.get_rng_state()
        cuda_state = torch.cuda.get_rng_state() if cuda_ok else None
        if seed is not None:
            torch.manual_seed(seed)
            if cuda_ok:
                torch.cuda.manual_seed(seed)
        return cpu_state, cuda_state

    @staticmethod
    def _restore_rng(state):
        """Restore an RNG state captured by :meth:`_save_and_seed_rng`."""
        cpu_state, cuda_state = state
        torch.set_rng_state(cpu_state)
        if cuda_state is not None:
            torch.cuda.set_rng_state(cuda_state)

    def _random_perm_split(self, train_ratio, val_ratio):
        """Randomly permute node ids and cut them into train/val/test parts."""
        num_nodes = self.num_nodes
        perm = torch.randperm(num_nodes)
        train_end = int(num_nodes * train_ratio)
        val_end = int(num_nodes * (train_ratio + val_ratio))
        return perm[:train_end], perm[train_end:val_end], perm[val_end:]

    def _sample_train_per_class(self, num_train_per_class, train_mask):
        """Mark up to ``num_train_per_class`` random nodes of every class in
        ``train_mask`` (in place) and return the chosen indices per class."""
        num_classes = int(self.y.max().item()) + 1
        chosen = []
        for c in range(num_classes):
            idx = (self.y == c).nonzero().view(-1)
            idx = idx[torch.randperm(idx.size(0))[:num_train_per_class]]
            train_mask[idx] = True
            chosen.append(idx)
        return chosen

    @staticmethod
    def _shuffled_complement(mask):
        """Return the indices where ``mask`` is False, in random order."""
        remaining = (~mask).nonzero().view(-1)
        return remaining[torch.randperm(remaining.size(0))]

    def _cleared_mask(self, name):
        """Return the boolean mask attribute ``name`` reset to all False.

        An existing mask is cleared in place; a missing one is created with
        one entry per node on the device of :obj:`y`.
        """
        mask = getattr(self, name, None)
        if mask is None:
            mask = torch.zeros(self.num_nodes, dtype=torch.bool, device=self.y.device)
            self[name] = mask
        else:
            mask.fill_(False)
        return mask

    # ------------------------------------------------------------------
    # random splits
    # ------------------------------------------------------------------

    def random_splits_mask(self, train_ratio, val_ratio, seed=None):
        r"""If the data has masks for train/val/test, return the splits with
        specific ratio.

        Parameters
        ----------
        train_ratio : float
            the portion of data that used for training.
        val_ratio : float
            the portion of data that used for validation.
        seed : int
            random seed for splitting dataset. The global RNG state is
            restored afterwards.
        """
        state = self._save_and_seed_rng(seed)
        try:
            num_nodes = self.num_nodes
            train_index, val_index, test_index = self._random_perm_split(
                train_ratio, val_ratio
            )
            self.train_mask = index_to_mask(train_index, size=num_nodes)
            self.val_mask = index_to_mask(val_index, size=num_nodes)
            self.test_mask = index_to_mask(test_index, size=num_nodes)
        finally:
            self._restore_rng(state)
        return self

    def random_splits_nodes(self, train_ratio, val_ratio, seed=None):
        r"""If the data uses id of nodes for train/val/test, return the splits
        with specific ratio.

        Sets :obj:`train_node`, :obj:`val_node`, :obj:`test_node` and the
        matching :obj:`train_target`, :obj:`valid_target`,
        :obj:`test_target`.

        Parameters
        ----------
        train_ratio : float
            the portion of data that used for training.
        val_ratio : float
            the portion of data that used for validation.
        seed : int
            random seed for splitting dataset. The global RNG state is
            restored afterwards.
        """
        state = self._save_and_seed_rng(seed)
        try:
            self.train_node, self.val_node, self.test_node = self._random_perm_split(
                train_ratio, val_ratio
            )
            self.train_target = self.y[self.train_node]
            # Attribute is named `val_node`; `valid_node` never existed.
            self.valid_target = self.y[self.val_node]
            self.test_target = self.y[self.test_node]
        finally:
            self._restore_rng(state)
        return self

    def random_splits_mask_class(
        self, num_train_per_class, num_val, num_test, seed=None
    ):
        r"""If the data has masks for train/val/test, return the splits with
        specific number of samples from every class for training.

        Existing masks are overwritten in place; missing masks are created.

        Parameters
        ----------
        num_train_per_class : int
            the number of samples from every class used for training.
        num_val : int
            the total number of nodes that used for validation.
        num_test : int
            the total number of nodes that used for testing.
        seed : int
            random seed for splitting dataset. The global RNG state is
            restored afterwards.
        """
        state = self._save_and_seed_rng(seed)
        try:
            train_mask = self._cleared_mask("train_mask")
            self._sample_train_per_class(num_train_per_class, train_mask)
            remaining = self._shuffled_complement(train_mask)
            val_mask = self._cleared_mask("val_mask")
            val_mask[remaining[:num_val]] = True
            test_mask = self._cleared_mask("test_mask")
            test_mask[remaining[num_val : num_val + num_test]] = True
        finally:
            self._restore_rng(state)
        return self

    def random_splits_nodes_class(
        self, num_train_per_class, num_val, num_test, seed=None
    ):
        r"""If the data uses id of nodes for train/val/test, return the splits
        with specific number of samples from every class for training.

        Parameters
        ----------
        num_train_per_class : int
            the number of samples from every class used for training.
        num_val : int
            the total number of nodes that used for validation.
        num_test : int
            the total number of nodes that used for testing.
        seed : int
            random seed for splitting dataset. The global RNG state is
            restored afterwards.
        """
        state = self._save_and_seed_rng(seed)
        try:
            # Allocate on y's device: the indices written into it come from y,
            # and train_node may not exist yet.
            train_mask = torch.zeros(
                self.num_nodes, dtype=torch.bool, device=self.y.device
            )
            chosen = self._sample_train_per_class(num_train_per_class, train_mask)
            self.train_node = torch.cat(chosen)
            remaining = self._shuffled_complement(train_mask)
            self.val_node = remaining[:num_val]
            self.test_node = remaining[num_val : num_val + num_test]
            self.train_target = self.y[self.train_node]
            self.valid_target = self.y[self.val_node]
            self.test_target = self.y[self.test_node]
        finally:
            self._restore_rng(state)
        return self

    def __inc__(self, key, value):
        r"""Returns the incremental count to cumulatively increase the value
        of the next attribute of :obj:`key` when creating batches.

        .. note::

            This method is for internal use only, and should only be
            overridden if the batch concatenation process is corrupted for a
            specific data attribute.
        """
        # Only `*index*` and `*face*` should be cumulatively summed up when
        # creating batches.
        return self.num_nodes if bool(re.search("(index|face)", key)) else 0

    @property
    def num_edges(self):
        r"""Returns the number of edges in the graph, or :obj:`None` if neither
        :obj:`edge_index` nor :obj:`edge_attr` is present."""
        for key, item in self("edge_index", "edge_attr"):
            return item.size(self.cat_dim(key, item))
        return None

    @property
    def num_features(self):
        r"""Returns the number of features per node in the graph (0 when
        :obj:`x` is absent)."""
        if self.x is None:
            return 0
        return 1 if self.x.dim() == 1 else self.x.size(1)

    @property
    def num_nodes(self):
        r"""Returns the number of nodes in the graph as a python :obj:`int`.

        Uses the first dimension of :obj:`x` when present; otherwise infers
        it as the largest index in :obj:`edge_index` plus one (0 for an empty
        :obj:`edge_index`). Returns :obj:`None` if neither is available.
        """
        if self.x is not None:
            return self.x.shape[0]
        if self.edge_index is None:
            return None
        if self.edge_index.numel() == 0:
            return 0
        # int() so callers get a python number rather than a 0-dim tensor.
        return int(self.edge_index.max()) + 1

    def is_coalesced(self):
        r"""Returns :obj:`True`, if edge indices are ordered and do not contain
        duplicate entries."""
        row, col = self.edge_index
        index = self.num_nodes * row + col
        # Strictly increasing linearized ids <=> sorted by (row, col) and
        # free of duplicates.
        return bool((index[1:] > index[:-1]).all())

    def apply(self, func, *keys):
        r"""Applies the function :obj:`func` to all attributes :obj:`*keys`.
        If :obj:`*keys` is not given, :obj:`func` is applied to all present
        attributes.
        """
        for key, item in self(*keys):
            self[key] = func(item)
        return self

    def contiguous(self, *keys):
        r"""Ensures a contiguous memory layout for all tensor attributes
        :obj:`*keys`.
        If :obj:`*keys` is not given, all present tensor attributes are
        ensured to have a contiguous memory layout. Non-tensor attributes are
        left untouched."""
        return self.apply(
            lambda x: x.contiguous() if torch.is_tensor(x) else x, *keys
        )

    def to(self, device, *keys, **kwargs):
        r"""Performs tensor dtype and/or device conversion to all tensor
        attributes :obj:`*keys`.
        If :obj:`*keys` is not given, the conversion is applied to all present
        tensor attributes. Extra keyword arguments (e.g. ``non_blocking``)
        are forwarded to :meth:`torch.Tensor.to`. Non-tensor attributes are
        left untouched."""
        return self.apply(
            lambda x: x.to(device, **kwargs) if torch.is_tensor(x) else x, *keys
        )

    def cuda(self, *keys):
        r"""Moves all tensor attributes :obj:`*keys` (default: all present) to
        the current CUDA device. Non-tensor attributes are left untouched."""
        return self.apply(lambda x: x.cuda() if torch.is_tensor(x) else x, *keys)

    def clone(self):
        r"""Returns a deep copy of the data object: tensors are cloned, other
        attributes (lists, dicts, ...) are deep-copied."""
        import copy

        return Data.from_dict(
            {
                k: v.clone() if torch.is_tensor(v) else copy.deepcopy(v)
                for k, v in self
            }
        )

    def __repr__(self):
        info = []
        for key, item in self:
            if torch.is_tensor(item):
                info.append("{}={}".format(key, list(item.size())))
            elif not isinstance(item, (list, dict)):
                # Lists/dicts are omitted as before; other non-tensor values
                # are shown directly instead of crashing on `.size()`.
                info.append("{}={!r}".format(key, item))
        return "{}({})".format(self.__class__.__name__, ", ".join(info))