Adding functional way of stacking DataPipes (#52507)
Summary:
Allows using the functional API to stack DataPipes:
```python
numbers_dp = NumbersDataset(size=10).filter(filter_fn=lambda x: x % 2 == 1).map(fn=lambda x: x * 10)
```
DataPipes have to be decorated with:
```python
@functional_datapipe('map')
class MapIterDataPipe(IterDataPipe[T_co]):
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52507
Reviewed By: ailzhang
Differential Revision: D26644079
Pulled By: VitalyFedyunin
fbshipit-source-id: dcf464637b4fcf9ea1eb8e84c2a0cd4dfd58b43d
diff --git a/torch/utils/data/__init__.py b/torch/utils/data/__init__.py
index cf18f2e..61e648d 100644
--- a/torch/utils/data/__init__.py
+++ b/torch/utils/data/__init__.py
@@ -2,6 +2,7 @@
from .dataset import (Dataset, IterableDataset, TensorDataset, ConcatDataset, ChainDataset, BufferedShuffleDataset,
Subset, random_split)
from .dataset import IterableDataset as IterDataPipe
+from .dataset import functional_datapipe
from .distributed import DistributedSampler
from .dataloader import DataLoader, _DatasetKind, get_worker_info
from . import datapipes
@@ -11,4 +12,4 @@
'DistributedSampler', 'Dataset', 'IterableDataset', 'TensorDataset',
'ConcatDataset', 'ChainDataset', 'BufferedShuffleDataset', 'Subset',
'random_split', 'DataLoader', '_DatasetKind', 'get_worker_info',
- 'IterDataPipe']
+ 'IterDataPipe', 'functional_datapipe']
diff --git a/torch/utils/data/datapipes/iter/callable.py b/torch/utils/data/datapipes/iter/callable.py
index 7a73ef7..49af479 100644
--- a/torch/utils/data/datapipes/iter/callable.py
+++ b/torch/utils/data/datapipes/iter/callable.py
@@ -1,6 +1,6 @@
import warnings
import torch.nn as nn
-from torch.utils.data import IterDataPipe, _utils
+from torch.utils.data import IterDataPipe, _utils, functional_datapipe
from typing import Callable, Dict, Iterator, Optional, Sized, Tuple, TypeVar
T_co = TypeVar('T_co', covariant=True)
@@ -12,7 +12,7 @@
def default_fn(data):
return data
-
+@functional_datapipe('map')
class MapIterDataPipe(IterDataPipe[T_co]):
r""" :class:`MapIterDataPipe`.
diff --git a/torch/utils/data/datapipes/iter/selecting.py b/torch/utils/data/datapipes/iter/selecting.py
index 97ed5ca..9e76390 100644
--- a/torch/utils/data/datapipes/iter/selecting.py
+++ b/torch/utils/data/datapipes/iter/selecting.py
@@ -1,4 +1,4 @@
-from torch.utils.data import IterDataPipe
+from torch.utils.data import IterDataPipe, functional_datapipe
from typing import Callable, TypeVar, Iterator, Optional, Tuple, Dict
from .callable import MapIterDataPipe
@@ -6,6 +6,7 @@
T_co = TypeVar('T_co', covariant=True)
+@functional_datapipe('filter')
class FilterIterDataPipe(MapIterDataPipe[T_co]):
r""" :class:`FilterIterDataPipe`.
diff --git a/torch/utils/data/dataset.py b/torch/utils/data/dataset.py
index 0bef57c..f847371 100644
--- a/torch/utils/data/dataset.py
+++ b/torch/utils/data/dataset.py
@@ -1,6 +1,7 @@
import bisect
import random
import warnings
+import functools
from torch._utils import _accumulate
from torch import randperm
@@ -12,6 +13,16 @@
T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')
+class functional_datapipe(object):
+ def __init__(self, name):
+ self.name = name
+
+ def __call__(self, cls):
+ if not issubclass(cls, IterableDataset):
+ raise Exception('Can only decorate IterDataPipe')
+ IterableDataset.register_datapipe_as_function(self.name, cls)
+ return cls
+
class Dataset(Generic[T_co]):
r"""An abstract class representing a :class:`Dataset`.
@@ -142,6 +153,7 @@
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
[3, 4, 5, 6]
"""
+ functions = {}
def __iter__(self) -> Iterator[T_co]:
raise NotImplementedError
@@ -152,6 +164,27 @@
# No `def __len__(self)` default?
# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
+ def __getattr__(self, attribute_name):
+ if attribute_name in IterableDataset.functions:
+ function = functools.partial(IterableDataset.functions[attribute_name], self)
+ return function
+ else:
+ raise AttributeError
+
+ @classmethod
+ def register_function(cls, function_name, function):
+ IterableDataset.functions[function_name] = function
+
+ @classmethod
+ def register_datapipe_as_function(cls, function_name, cls_to_register):
+ if function_name in IterableDataset.functions:
+ raise Exception("Unable to add DataPipe function name {} as it is already taken".format(function_name))
+
+ def class_function(cls, source_dp, *args, **kwargs):
+ return cls(source_dp, *args, **kwargs)
+ function = functools.partial(class_function, cls_to_register)
+ IterableDataset.functions[function_name] = function
+
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
r"""Dataset wrapping tensors.