Adding functional way of stacking DataPipes (#52507)

Summary:
Allows using the functional API to stack DataPipes:
```python
numbers_dp = NumbersDataset(size=10).filter(filter_fn=lambda x: x % 2 == 1).map(fn=lambda x: x * 10)
```

DataPipes must be decorated with:
```python
functional_datapipe('map')
class MapIterDataPipe(IterDataPipe[T_co]):
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/52507

Reviewed By: ailzhang

Differential Revision: D26644079

Pulled By: VitalyFedyunin

fbshipit-source-id: dcf464637b4fcf9ea1eb8e84c2a0cd4dfd58b43d
diff --git a/torch/utils/data/__init__.py b/torch/utils/data/__init__.py
index cf18f2e..61e648d 100644
--- a/torch/utils/data/__init__.py
+++ b/torch/utils/data/__init__.py
@@ -2,6 +2,7 @@
 from .dataset import (Dataset, IterableDataset, TensorDataset, ConcatDataset, ChainDataset, BufferedShuffleDataset,
                       Subset, random_split)
 from .dataset import IterableDataset as IterDataPipe
+from .dataset import functional_datapipe
 from .distributed import DistributedSampler
 from .dataloader import DataLoader, _DatasetKind, get_worker_info
 from . import datapipes
@@ -11,4 +12,4 @@
            'DistributedSampler', 'Dataset', 'IterableDataset', 'TensorDataset',
            'ConcatDataset', 'ChainDataset', 'BufferedShuffleDataset', 'Subset',
            'random_split', 'DataLoader', '_DatasetKind', 'get_worker_info',
-           'IterDataPipe']
+           'IterDataPipe', 'functional_datapipe']
diff --git a/torch/utils/data/datapipes/iter/callable.py b/torch/utils/data/datapipes/iter/callable.py
index 7a73ef7..49af479 100644
--- a/torch/utils/data/datapipes/iter/callable.py
+++ b/torch/utils/data/datapipes/iter/callable.py
@@ -1,6 +1,6 @@
 import warnings
 import torch.nn as nn
-from torch.utils.data import IterDataPipe, _utils
+from torch.utils.data import IterDataPipe, _utils, functional_datapipe
 from typing import Callable, Dict, Iterator, Optional, Sized, Tuple, TypeVar
 
 T_co = TypeVar('T_co', covariant=True)
@@ -12,7 +12,7 @@
 def default_fn(data):
     return data
 
-
+@functional_datapipe('map')
 class MapIterDataPipe(IterDataPipe[T_co]):
     r""" :class:`MapIterDataPipe`.
 
diff --git a/torch/utils/data/datapipes/iter/selecting.py b/torch/utils/data/datapipes/iter/selecting.py
index 97ed5ca..9e76390 100644
--- a/torch/utils/data/datapipes/iter/selecting.py
+++ b/torch/utils/data/datapipes/iter/selecting.py
@@ -1,4 +1,4 @@
-from torch.utils.data import IterDataPipe
+from torch.utils.data import IterDataPipe, functional_datapipe
 from typing import Callable, TypeVar, Iterator, Optional, Tuple, Dict
 
 from .callable import MapIterDataPipe
@@ -6,6 +6,7 @@
 T_co = TypeVar('T_co', covariant=True)
 
 
+@functional_datapipe('filter')
 class FilterIterDataPipe(MapIterDataPipe[T_co]):
     r""" :class:`FilterIterDataPipe`.
 
diff --git a/torch/utils/data/dataset.py b/torch/utils/data/dataset.py
index 0bef57c..f847371 100644
--- a/torch/utils/data/dataset.py
+++ b/torch/utils/data/dataset.py
@@ -1,6 +1,7 @@
 import bisect
 import random
 import warnings
+import functools
 
 from torch._utils import _accumulate
 from torch import randperm
@@ -12,6 +13,16 @@
 T_co = TypeVar('T_co', covariant=True)
 T = TypeVar('T')
 
+class functional_datapipe(object):
+    def __init__(self, name):
+        self.name = name
+
+    def __call__(self, cls):
+        if not issubclass(cls, IterableDataset):
+            raise Exception('Can only decorate IterDataPipe')
+        IterableDataset.register_datapipe_as_function(self.name, cls)
+        return cls
+
 
 class Dataset(Generic[T_co]):
     r"""An abstract class representing a :class:`Dataset`.
@@ -142,6 +153,7 @@
         >>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
         [3, 4, 5, 6]
     """
+    functions = {}
 
     def __iter__(self) -> Iterator[T_co]:
         raise NotImplementedError
@@ -152,6 +164,27 @@
     # No `def __len__(self)` default?
     # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
 
+    def __getattr__(self, attribute_name):
+        if attribute_name in IterableDataset.functions:
+            function = functools.partial(IterableDataset.functions[attribute_name], self)
+            return function
+        else:
+            raise AttributeError
+
+    @classmethod
+    def register_function(cls, function_name, function):
+        IterableDataset.functions[function_name] = function
+
+    @classmethod
+    def register_datapipe_as_function(cls, function_name, cls_to_register):
+        if function_name in IterableDataset.functions:
+            raise Exception("Unable to add DataPipe function name {} as it is already taken".format(function_name))
+
+        def class_function(cls, source_dp, *args, **kwargs):
+            return cls(source_dp, *args, **kwargs)
+        function = functools.partial(class_function, cls_to_register)
+        IterableDataset.functions[function_name] = function
+
 
 class TensorDataset(Dataset[Tuple[Tensor, ...]]):
     r"""Dataset wrapping tensors.