import typing as _typing
# Type variable for the element type of a dataset (used as ``Dataset[_D]``).
_D = _typing.TypeVar('_D')
class _Schema(_typing.MutableMapping[str, _typing.Any]):
def __setitem__(self, key: str, value: _typing.Any) -> None:
self.__data[key] = value
def __delitem__(self, key: str) -> None:
del self.__data[key]
def __getitem__(self, key: str) -> _typing.Any:
return self.__data[key]
def __len__(self) -> int:
return len(self.__data)
def __iter__(self) -> _typing.Iterator[str]:
return iter(self.__data)
def __init__(self):
self.__data: _typing.MutableMapping[str, _typing.Any] = {}
self.__meta_paths: _typing.Optional[
_typing.Iterable[_typing.Iterable[str]]
] = None
@property
def meta_paths(self) -> _typing.Optional[
_typing.Iterable[_typing.Iterable[str]]
]:
return self.__meta_paths
@meta_paths.setter
def meta_paths(
self, meta_paths: _typing.Optional[
_typing.Iterable[_typing.Iterable[str]]
]
):
self.__meta_paths = meta_paths
class Dataset(_typing.Iterable[_D], _typing.Sized):
    """Interface for a sized, iterable, index-addressable collection of items.

    Every member defined here raises ``NotImplementedError``; subclasses are
    expected to override the members they support.
    """

    # --- container protocol ---

    def __len__(self) -> int:
        """Return the number of items."""
        raise NotImplementedError

    def __iter__(self) -> _typing.Iterator[_D]:
        """Iterate over the items."""
        raise NotImplementedError

    def __getitem__(self, index: int) -> _D:
        """Return the item at position ``index``."""
        raise NotImplementedError

    def __setitem__(self, index: int, data: _D):
        """Replace the item at position ``index`` with ``data``."""
        raise NotImplementedError

    # --- items selected by each split ---

    @property
    def train_split(self) -> _typing.Optional[_typing.Iterable[_D]]:
        """Items of the training split, or ``None``."""
        raise NotImplementedError

    @property
    def val_split(self) -> _typing.Optional[_typing.Iterable[_D]]:
        """Items of the validation split, or ``None``."""
        raise NotImplementedError

    @property
    def test_split(self) -> _typing.Optional[_typing.Iterable[_D]]:
        """Items of the test split, or ``None``."""
        raise NotImplementedError

    # --- item positions of each split (readable and assignable) ---

    @property
    def train_index(self) -> _typing.Optional[_typing.AbstractSet[int]]:
        """Item positions of the training split, or ``None``."""
        raise NotImplementedError

    @train_index.setter
    def train_index(self, train_index: _typing.Optional[_typing.Iterable[int]]):
        raise NotImplementedError

    @property
    def val_index(self) -> _typing.Optional[_typing.AbstractSet[int]]:
        """Item positions of the validation split, or ``None``."""
        raise NotImplementedError

    @val_index.setter
    def val_index(self, val_index: _typing.Optional[_typing.Iterable[int]]):
        raise NotImplementedError

    @property
    def test_index(self) -> _typing.Optional[_typing.AbstractSet[int]]:
        """Item positions of the test split, or ``None``."""
        raise NotImplementedError

    @test_index.setter
    def test_index(self, test_index: _typing.Optional[_typing.Iterable[int]]):
        raise NotImplementedError

    # --- schema ---

    @property
    def schema(self) -> _Schema:
        """The :class:`_Schema` associated with this dataset."""
        raise NotImplementedError
class _FoldsContainer:
def __init__(
self,
folds: _typing.Optional[_typing.Iterable[_typing.Tuple[_typing.Sequence[int], _typing.Sequence[int]]]] = ...
):
self._folds: _typing.Optional[_typing.List[_typing.Tuple[_typing.Sequence[int], _typing.Sequence[int]]]] = (
list(folds) if isinstance(folds, _typing.Iterable) else None
)
if self._folds is not None and len(self._folds) == 0:
self._folds = None
@property
def folds(self) -> _typing.Optional[_typing.Sequence[_typing.Tuple[_typing.Sequence[int], _typing.Sequence[int]]]]:
if self._folds is not None and len(self._folds) == 0:
self._folds = None
return self._folds
@folds.setter
def folds(self, folds: _typing.Optional[_typing.Iterable[_typing.Tuple[_typing.Sequence[int], _typing.Sequence[int]]]]):
self._folds: _typing.Optional[_typing.List[_typing.Tuple[_typing.Sequence[int], _typing.Sequence[int]]]] = (
list(folds) if isinstance(folds, _typing.Iterable) else None
)
if self._folds is not None and len(self._folds) == 0:
self._folds = None
class _FoldView:
    """View of one fold inside a :class:`_FoldsContainer`.

    Only the container and the fold position are stored, so every property
    access reads the container's current ``folds``.
    """

    def __init__(self, folds_container: _FoldsContainer, fold_index: int):
        """Remember the container and the position of the fold to expose."""
        self._folds_container: _FoldsContainer = folds_container
        self._fold_index: int = fold_index

    @property
    def train_index(self) -> _typing.Sequence[int]:
        """First element of the fold's pair."""
        fold = self._folds_container.folds[self._fold_index]
        return fold[0]

    @property
    def val_index(self) -> _typing.Sequence[int]:
        """Second element of the fold's pair."""
        fold = self._folds_container.folds[self._fold_index]
        return fold[1]
class _FoldsView(_typing.Sequence[_FoldView]):
    """Read-only sequence of :class:`_FoldView` objects over a folds container.

    The view holds only a reference to the container, so its length and items
    always reflect the container's current ``folds``.
    """

    def __init__(self, folds_container: _FoldsContainer):
        """Wrap ``folds_container`` without copying its folds."""
        self._folds_container = folds_container

    def __len__(self) -> int:
        """Return the number of folds in the container (0 when it has none)."""
        folds = self._folds_container.folds
        return 0 if folds is None else len(folds)

    def __getitem__(self, fold_index: int) -> _FoldView:
        """Return a view of the fold at ``fold_index``.

        Negative positions are accepted and forwarded unchanged to the
        returned :class:`_FoldView`.

        Raises:
            IndexError: if ``fold_index`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        # The mixin methods inherited from Sequence (__iter__, __contains__,
        # index, count) advance an integer position until __getitem__ raises
        # IndexError. Without this range check they would never terminate.
        if not -size <= fold_index < size:
            raise IndexError("fold index out of range")
        return _FoldView(self._folds_container, fold_index)
class InMemoryDataset(Dataset[_D]):
    """:class:`Dataset` implementation that keeps every item in an in-memory list.

    Optional train / validation / test index collections select items of that
    list by position. An optional collection of folds is kept in a
    :class:`_FoldsContainer` and exposed through :attr:`folds`.
    """

    def __init__(
            self, data: _typing.Iterable[_D],
            train_index: _typing.Optional[_typing.Iterable[int]] = ...,
            val_index: _typing.Optional[_typing.Iterable[int]] = ...,
            test_index: _typing.Optional[_typing.Iterable[int]] = ...,
            schema: _typing.Optional[_Schema] = ...
    ):
        """Build the dataset.

        Args:
            data: items to store; they are copied into a new list.
            train_index: item positions of the training split. Anything that
                is not iterable (including ``None`` and the ``...`` default)
                is stored as ``None``.
            val_index: same as ``train_index``, for the validation split.
            test_index: same as ``train_index``, for the test split.
            schema: used as-is when it is a :class:`_Schema`; otherwise a new
                empty :class:`_Schema` is created.

        Unlike the property setters, the constructor does not validate the
        positions it is given.
        """
        self.__data: _typing.MutableSequence[_D] = list(data)
        self.__train_index: _typing.Optional[_typing.Iterable[int]] = (
            self.__initial_index(train_index)
        )
        self.__val_index: _typing.Optional[_typing.Iterable[int]] = (
            self.__initial_index(val_index)
        )
        self.__test_index: _typing.Optional[_typing.Iterable[int]] = (
            self.__initial_index(test_index)
        )
        self.__schema: _Schema = schema if isinstance(schema, _Schema) else _Schema()
        self.__folds_container: _FoldsContainer = _FoldsContainer()

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def __initial_index(index):
        """Normalise a constructor index argument without validating it."""
        if not isinstance(index, _typing.Iterable):
            return None
        if isinstance(index, _typing.Iterator):
            # A one-shot iterator would be exhausted by the first split
            # access; keep a list so the split can be read repeatedly.
            return list(index)
        return index

    def __checked_index(self, index):
        """Validate an index assigned through a setter and return the value to store.

        Returns:
            ``None`` when ``index`` is ``None`` or empty; otherwise ``index``
            itself, or a list of its elements when ``index`` is a one-shot
            iterator.

        Raises:
            TypeError: if ``index`` is neither ``None`` nor iterable, or if
                any element is not an ``int``.
            ValueError: if any element lies outside ``[0, len(self))``.
        """
        if index is None:
            return None
        if not isinstance(index, _typing.Iterable):
            raise TypeError("index must be None or an iterable of int")
        # Traverse the argument exactly once: the checks below need several
        # passes, which a generator or other iterator cannot provide.
        positions = list(index)
        if not positions:
            return None
        if not all(isinstance(position, int) for position in positions):
            raise TypeError("every index element must be an int")
        if not 0 <= min(positions) <= max(positions) < len(self):
            raise ValueError("index elements must lie in [0, len(dataset))")
        # Re-iterable containers are stored unchanged; only iterators are
        # replaced by the list built above.
        return positions if isinstance(index, _typing.Iterator) else index

    def __select(self, index):
        """Return the stored items at the positions in ``index``, or ``None``."""
        if not isinstance(index, _typing.Iterable):
            return None
        return [self.__data[position] for position in index]

    # ------------------------------------------------------------------
    # schema and folds
    # ------------------------------------------------------------------

    @property
    def schema(self) -> _Schema:
        """The schema given at construction, or the empty one created then."""
        return self.__schema

    @property
    def folds(self) -> _typing.Optional[_FoldsView]:
        """A :class:`_FoldsView` over the stored folds, or ``None`` if there are none."""
        # The container already reports an empty collection as ``None``.
        if self.__folds_container.folds is None:
            return None
        return _FoldsView(self.__folds_container)

    @folds.setter
    def folds(
            self,
            folds: _typing.Optional[
                _typing.Iterable[
                    _typing.Tuple[_typing.Sequence[int], _typing.Sequence[int]]
                ]
            ] = ...
    ):
        self.__folds_container.folds = folds

    # ------------------------------------------------------------------
    # container protocol
    # ------------------------------------------------------------------

    def __len__(self) -> int:
        """Return the number of stored items."""
        return len(self.__data)

    def __iter__(self) -> _typing.Iterator[_D]:
        """Iterate over the stored items in order."""
        return iter(self.__data)

    def __getitem__(self, index: int) -> _D:
        """Return the item at position ``index``."""
        return self.__data[index]

    def __setitem__(self, index: int, data: _D):
        """Replace the item at position ``index`` with ``data``."""
        self.__data[index] = data

    # ------------------------------------------------------------------
    # splits
    # ------------------------------------------------------------------

    @property
    def train_split(self) -> _typing.Optional[_typing.Iterable[_D]]:
        """List of the items selected by ``train_index``, or ``None``."""
        return self.__select(self.__train_index)

    @property
    def val_split(self) -> _typing.Optional[_typing.Iterable[_D]]:
        """List of the items selected by ``val_index``, or ``None``."""
        return self.__select(self.__val_index)

    @property
    def test_split(self) -> _typing.Optional[_typing.Iterable[_D]]:
        """List of the items selected by ``test_index``, or ``None``."""
        return self.__select(self.__test_index)

    # ------------------------------------------------------------------
    # split indices
    #
    # NOTE: the getters keep the return annotation of the base class, but
    # they return whatever iterable was stored, which need not be a set.
    # ------------------------------------------------------------------

    @property
    def train_index(self) -> _typing.Optional[_typing.AbstractSet[int]]:
        """Stored item positions of the training split, or ``None``."""
        return self.__train_index

    @train_index.setter
    def train_index(self, train_index: _typing.Optional[_typing.Iterable[int]]):
        # Raises TypeError / ValueError on invalid input; an empty iterable
        # is stored as None.
        self.__train_index = self.__checked_index(train_index)

    @property
    def val_index(self) -> _typing.Optional[_typing.AbstractSet[int]]:
        """Stored item positions of the validation split, or ``None``."""
        return self.__val_index

    @val_index.setter
    def val_index(self, val_index: _typing.Optional[_typing.Iterable[int]]):
        # Raises TypeError / ValueError on invalid input; an empty iterable
        # is stored as None.
        self.__val_index = self.__checked_index(val_index)

    @property
    def test_index(self) -> _typing.Optional[_typing.AbstractSet[int]]:
        """Stored item positions of the test split, or ``None``."""
        return self.__test_index

    @test_index.setter
    def test_index(self, test_index: _typing.Optional[_typing.Iterable[int]]):
        # Raises TypeError / ValueError on invalid input; an empty iterable
        # is stored as None.
        self.__test_index = self.__checked_index(test_index)