[WIP] Example of DataPipes and DataFrames integration (#60840)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/60840
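
For reference, a minimal usage sketch of the API this example introduces (the dataset class mirrors the one in the included notebook; `_to_dataframes_pipe` and the captured DataFrames operations are experimental and may change):

    from torch.utils.data import IterDataPipe

    class ExampleIterPipe(IterDataPipe):
        def __init__(self, size=10):
            self.size = size

        def __iter__(self):
            for i in range(self.size):
                yield i

    # Aggregate (i, i % 3) tuples into pandas DataFrames of up to 3 rows each.
    dp = ExampleIterPipe(10).map(lambda i: (i, i % 3)) \
        ._to_dataframes_pipe(columns=['i', 'j'], dataframe_size=3)

    # Column operations are captured lazily and replayed per DataFrame chunk.
    dp['y'] = dp.i * 100 + dp.j - 2.7

    # Rows can be shuffled and batched; iterating yields row tuples (or batches
    # of them), just like a regular DataPipe after the same operations.
    for batch in dp.shuffle().batch(2):
        print(batch)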

Test Plan: Imported from OSS

Reviewed By: wenleix, ejguan

Differential Revision: D29461080

Pulled By: VitalyFedyunin

fbshipit-source-id: 4909394dcd39e97ee49b699fda542b311b7e0d82
diff --git a/test/test_datapipe.py b/test/test_datapipe.py
index c82f81b..4d5bd7c 100644
--- a/test/test_datapipe.py
+++ b/test/test_datapipe.py
@@ -65,6 +65,13 @@
     HAS_DILL = False
 skipIfNoDill = skipIf(not HAS_DILL, "no dill")
 
+try:
+    import pandas  # type: ignore[import] # noqa: F401 F403
+    HAS_PANDAS = True
+except ImportError:
+    HAS_PANDAS = False
+skipIfNoDataFrames = skipIf(not HAS_PANDAS, "no dataframes (pandas)")
+
 T_co = TypeVar("T_co", covariant=True)
 
 
@@ -393,6 +400,64 @@
         self.assertEqual(source_numbers, list(n))
 
 
+class TestDataFramesPipes(TestCase):
+    """
+        Most of these tests will fail if pandas is installed but dill is not available.
+        Need to rework them to avoid multiple skips.
+    """
+    def _get_datapipe(self, range=10, dataframe_size=7):
+        return NumbersDataset(range) \
+            .map(lambda i: (i, i % 3))
+
+    def _get_dataframes_pipe(self, range=10, dataframe_size=7):
+        return NumbersDataset(range) \
+            .map(lambda i: (i, i % 3)) \
+            ._to_dataframes_pipe(
+                columns=['i', 'j'],
+                dataframe_size=dataframe_size)
+
+    @skipIfNoDataFrames
+    @skipIfNoDill  # TODO(VitalyFedyunin): Decouple tests from dill by avoiding lambdas in map
+    def test_capture(self):
+        dp_numbers = self._get_datapipe().map(lambda x: (x[0], x[1], x[1] + 3 * x[0]))
+        df_numbers = self._get_dataframes_pipe()
+        df_numbers['k'] = df_numbers['j'] + df_numbers.i * 3
+        self.assertEqual(list(dp_numbers), list(df_numbers))
+
+    @skipIfNoDataFrames
+    @skipIfNoDill
+    def test_shuffle(self):
+        #  With non-zero (but extremely low) probability (when shuffle leaves the
+        #  order unchanged), this test fails, so feel free to restart it
+        df_numbers = self._get_dataframes_pipe(range=1000).shuffle()
+        dp_numbers = self._get_datapipe(range=1000)
+        df_result = [tuple(item) for item in df_numbers]
+        self.assertNotEqual(list(dp_numbers), df_result)
+        self.assertEqual(list(dp_numbers), sorted(df_result))
+
+    @skipIfNoDataFrames
+    @skipIfNoDill
+    def test_batch(self):
+        df_numbers = self._get_dataframes_pipe(range=100).batch(8)
+        df_numbers_list = list(df_numbers)
+        last_batch = df_numbers_list[-1]
+        self.assertEqual(4, len(last_batch))
+        unpacked_batch = [tuple(row) for row in last_batch]
+        self.assertEqual([(96, 0), (97, 1), (98, 2), (99, 0)], unpacked_batch)
+
+    @skipIfNoDataFrames
+    @skipIfNoDill
+    def test_unbatch(self):
+        df_numbers = self._get_dataframes_pipe(range=100).batch(8).batch(3)
+        dp_numbers = self._get_datapipe(range=100)
+        self.assertEqual(list(dp_numbers), list(df_numbers.unbatch(2)))
+
+    @skipIfNoDataFrames
+    @skipIfNoDill
+    def test_filter(self):
+        df_numbers = self._get_dataframes_pipe(range=10).filter(lambda x: x.i > 5)
+        self.assertEqual([(6, 0), (7, 1), (8, 2), (9, 0)], list(df_numbers))
+
 class FileLoggerSimpleHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
     def __init__(self, *args, logfile=None, **kwargs):
         self.__loggerHandle = None
diff --git a/tools/linter/clang_tidy/run.py b/tools/linter/clang_tidy/run.py
index f84d225..9e71333 100644
--- a/tools/linter/clang_tidy/run.py
+++ b/tools/linter/clang_tidy/run.py
@@ -421,7 +421,7 @@
                 i += 1
             ranges[-1][1] = added_line_nos[-1]
 
-            files[file.path].append(*ranges)
+            files[file.path] += ranges
 
     return dict(files)
 
diff --git a/torch/utils/data/__init__.py b/torch/utils/data/__init__.py
index ac0c763..167a860 100644
--- a/torch/utils/data/__init__.py
+++ b/torch/utils/data/__init__.py
@@ -14,6 +14,7 @@
     DataChunk,
     Dataset,
     Dataset as MapDataPipe,
+    DFIterDataPipe,
     IterableDataset,
     IterableDataset as IterDataPipe,
     Subset,
diff --git a/torch/utils/data/_decorator.py b/torch/utils/data/_decorator.py
index 6e64fd5..53dce59 100644
--- a/torch/utils/data/_decorator.py
+++ b/torch/utils/data/_decorator.py
@@ -11,8 +11,14 @@
 class functional_datapipe(object):
     name: str
 
-    def __init__(self, name: str) -> None:
+    def __init__(self, name: str, enable_df_api_tracing=False) -> None:
+        """
+            Args:
+                enable_df_api_tracing - if set, any returned DataPipe will accept
+                the DataFrames API in tracing mode.
+        """
         self.name = name
+        self.enable_df_api_tracing = enable_df_api_tracing
 
     def __call__(self, cls):
         if issubclass(cls, IterDataPipe):
@@ -25,9 +31,9 @@
                     not (hasattr(cls, '__self__') and
                          isinstance(cls.__self__, non_deterministic)):
                     raise TypeError('`functional_datapipe` can only decorate IterDataPipe')
-            IterDataPipe.register_datapipe_as_function(self.name, cls)
+            IterDataPipe.register_datapipe_as_function(self.name, cls, enable_df_api_tracing=self.enable_df_api_tracing)
         elif issubclass(cls, MapDataPipe):
-            MapDataPipe.register_datapipe_as_function(self.name, cls)
+            MapDataPipe.register_datapipe_as_function(self.name, cls, enable_df_api_tracing=self.enable_df_api_tracing)
 
         return cls
 
diff --git a/torch/utils/data/dataframes_pipes.ipynb b/torch/utils/data/dataframes_pipes.ipynb
new file mode 100644
index 0000000..74e36ce
--- /dev/null
+++ b/torch/utils/data/dataframes_pipes.ipynb
@@ -0,0 +1,523 @@
+{
+ "metadata": {
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.10"
+  },
+  "orig_nbformat": 2,
+  "kernelspec": {
+   "name": "python3610jvsc74a57bd0eb5e09632d6ea1cbf3eb9da7e37b7cf581db5ed13074b21cc44e159dc62acdab",
+   "display_name": "Python 3.6.10 64-bit ('dataloader': conda)"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2,
+ "cells": [
+  {
+   "source": [
+    "## \\[RFC\\] How DataFrames (DF) and DataPipes (DP) work together"
+   ],
+   "cell_type": "markdown",
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from importlib import reload\n",
+    "import torch\n",
+    "reload(torch)\n",
+    "from torch.utils.data import IterDataPipe"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Example IterDataPipe\n",
+    "class ExampleIterPipe(IterDataPipe):\n",
+    "    def __init__(self, range = 20):\n",
+    "        self.range = range\n",
+    "    def __iter__(self):\n",
+    "        for i in range(self.range):\n",
+    "            yield i\n",
+    "\n",
+    "def get_dataframes_pipe(range = 10, dataframe_size = 7):\n",
+    "    return ExampleIterPipe(range = range).map(lambda i: (i, i % 3))._to_dataframes_pipe(columns = ['i','j'], dataframe_size = dataframe_size)\n",
+    "\n",
+    "def get_regular_pipe(range = 10):\n",
+    "    return ExampleIterPipe(range = range).map(lambda i: (i, i % 3))\n"
+   ]
+  },
+  {
+   "source": [
+    "Doesn't matter how DF composed internally, iterator over DF Pipe gives single rows to user. This is similar to regular DataPipe."
+   ],
+   "cell_type": "markdown",
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "DataFrames Pipe\n(0, 0)\n(1, 1)\n(2, 2)\n(3, 0)\n(4, 1)\n(5, 2)\n(6, 0)\n(7, 1)\n(8, 2)\n(9, 0)\nRegular DataPipe\n(0, 0)\n(1, 1)\n(2, 2)\n(3, 0)\n(4, 1)\n(5, 2)\n(6, 0)\n(7, 1)\n(8, 2)\n(9, 0)\n"
+     ]
+    }
+   ],
+   "source": [
+    "print('DataFrames Pipe')\n",
+    "dp = get_dataframes_pipe()\n",
+    "for i in dp:\n",
+    "    print(i)\n",
+    "\n",
+    "print('Regular DataPipe')\n",
+    "dp = get_regular_pipe()\n",
+    "for i in dp:\n",
+    "    print(i)"
+   ]
+  },
+  {
+   "source": [
+    "You can iterate over raw DF using `raw_iterator`"
+   ],
+   "cell_type": "markdown",
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "   i  j\n0  0  0\n1  1  1\n2  2  2\n3  3  0\n4  4  1\n5  5  2\n6  6  0\n   i  j\n0  7  1\n1  8  2\n2  9  0\n"
+     ]
+    }
+   ],
+   "source": [
+    "dp = get_dataframes_pipe()\n",
+    "for i in dp.raw_iterator():\n",
+    "    print(i)"
+   ]
+  },
+  {
+   "source": [
+    "Operations over DF Pipe is captured"
+   ],
+   "cell_type": "markdown",
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "var_3 = input_var_2.i * 100\nvar_4 = var_3 + input_var_2.j\nvar_5 = var_4 - 2.7\ninput_var_2[\"y\"] = var_5\n"
+     ]
+    }
+   ],
+   "source": [
+    "dp = get_dataframes_pipe(dataframe_size = 3)\n",
+    "dp['y'] = dp.i * 100 + dp.j - 2.7\n",
+    "print(dp.ops_str())\n"
+   ]
+  },
+  {
+   "source": [
+    "Captured operations executed on `__next__` calls of constructed DataPipe"
+   ],
+   "cell_type": "markdown",
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "   i  j      y\n0  0  0   -2.7\n1  1  1   98.3\n2  2  2  199.3\n   i  j      y\n0  3  0  297.3\n1  4  1  398.3\n2  5  2  499.3\n   i  j      y\n0  6  0  597.3\n1  7  1  698.3\n2  8  2  799.3\n   i  j      y\n0  9  0  897.3\n"
+     ]
+    }
+   ],
+   "source": [
+    "dp = get_dataframes_pipe(dataframe_size = 3)\n",
+    "dp['y'] = dp.i * 100 + dp.j - 2.7\n",
+    "for i in dp.raw_iterator():\n",
+    "    print(i)"
+   ]
+  },
+  {
+   "source": [
+    "`shuffle` of DataFramePipe effects rows in individual manner"
+   ],
+   "cell_type": "markdown",
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "Raw DataFrames iterator\n   i  j\n2  8  2\n2  2  2\n2  5  2\n   i  j\n1  4  1\n1  1  1\n0  3  0\n   i  j\n1  7  1\n0  9  0\n0  6  0\n   i  j\n0  0  0\nRegular DataFrames iterator\n(1, 1)\n(5, 2)\n(8, 2)\n(9, 0)\n(7, 1)\n(6, 0)\n(3, 0)\n(4, 1)\n(0, 0)\n(2, 2)\nRegular iterator\n(5, 2)\n(6, 0)\n(0, 0)\n(9, 0)\n(3, 0)\n(1, 1)\n(2, 2)\n(8, 2)\n(4, 1)\n(7, 1)\n"
+     ]
+    }
+   ],
+   "source": [
+    "dp = get_dataframes_pipe(dataframe_size = 3)\n",
+    "dp = dp.shuffle()\n",
+    "print('Raw DataFrames iterator')\n",
+    "for i in dp.raw_iterator():\n",
+    "    print(i)\n",
+    "\n",
+    "print('Regular DataFrames iterator')\n",
+    "for i in dp:\n",
+    "    print(i)\n",
+    "\n",
+    "\n",
+    "# this is similar to shuffle of regular DataPipe\n",
+    "dp = get_regular_pipe()\n",
+    "dp = dp.shuffle()\n",
+    "print('Regular iterator')\n",
+    "for i in dp:\n",
+    "    print(i)"
+   ]
+  },
+  {
+   "source": [
+    "You can continue mixing DF and DP operations"
+   ],
+   "cell_type": "markdown",
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "    i   j          y\n0 -17 -17  -197000.0\n1 -13 -16  3813000.0\n0 -11 -17  5803000.0\n    i   j          y\n2 -12 -15  4823000.0\n1 -10 -16  6813000.0\n1 -16 -16   813000.0\n    i   j          y\n0  -8 -17  8803000.0\n2  -9 -15  7823000.0\n0 -14 -17  2803000.0\n    i   j          y\n2 -15 -15  1823000.0\n"
+     ]
+    }
+   ],
+   "source": [
+    "dp = get_dataframes_pipe(dataframe_size = 3)\n",
+    "dp['y'] = dp.i * 100 + dp.j - 2.7\n",
+    "dp = dp.shuffle()\n",
+    "dp = dp - 17\n",
+    "dp['y'] = dp.y * 10000\n",
+    "for i in dp.raw_iterator():\n",
+    "    print(i)"
+   ]
+  },
+  {
+   "source": [
+    "Batching combines everything into `list` it is possible to nest `list`s. List may have any number of DataFrames as soon as total number of rows equal to batch size."
+   ],
+   "cell_type": "markdown",
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "Iterate over DataFrame batches\n[(6, 0),(0, 0)]\n[(4, 1),(1, 1)]\n[(2, 2),(9, 0)]\n[(3, 0),(5, 2)]\n[(7, 1),(8, 2)]\nIterate over regular batches\n[(1, 1),(4, 1)]\n[(2, 2),(3, 0)]\n[(6, 0),(7, 1)]\n[(8, 2),(0, 0)]\n[(5, 2),(9, 0)]\n"
+     ]
+    }
+   ],
+   "source": [
+    "dp = get_dataframes_pipe(dataframe_size = 3)\n",
+    "dp = dp.shuffle()\n",
+    "dp = dp.batch(2)\n",
+    "print(\"Iterate over DataFrame batches\")\n",
+    "for i,v in enumerate(dp):\n",
+    "    print(v)\n",
+    "\n",
+    "# this is similar to batching of regular DataPipe\n",
+    "dp = get_regular_pipe()\n",
+    "dp = dp.shuffle()\n",
+    "dp = dp.batch(2)\n",
+    "print(\"Iterate over regular batches\")\n",
+    "for i in dp:\n",
+    "    print(i)"
+   ]
+  },
+  {
+   "source": [
+    "Some details about internal storage of batched DataFrames and how they are iterated"
+   ],
+   "cell_type": "markdown",
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "Type:  <class 'torch.utils.data.datapipes.iter.dataframes.DataChunkDF'>\n",
+      "As string:  [(0, 0),(3, 0)]\n",
+      "Iterated regularly:\n",
+      "-- batch start --\n",
+      "(0, 0)\n",
+      "(3, 0)\n",
+      "-- batch end --\n",
+      "Iterated in inner format (for developers):\n",
+      "-- df batch start --\n",
+      "   i  j\n",
+      "0  0  0\n",
+      "0  3  0\n",
+      "-- df batch end --\n",
+      "Type:  <class 'torch.utils.data.datapipes.iter.dataframes.DataChunkDF'>\n",
+      "As string:  [(6, 0),(1, 1)]\n",
+      "Iterated regularly:\n",
+      "-- batch start --\n",
+      "(6, 0)\n",
+      "(1, 1)\n",
+      "-- batch end --\n",
+      "Iterated in inner format (for developers):\n",
+      "-- df batch start --\n",
+      "   i  j\n",
+      "0  6  0\n",
+      "1  1  1\n",
+      "-- df batch end --\n",
+      "Type:  <class 'torch.utils.data.datapipes.iter.dataframes.DataChunkDF'>\n",
+      "As string:  [(9, 0),(4, 1)]\n",
+      "Iterated regularly:\n",
+      "-- batch start --\n",
+      "(9, 0)\n",
+      "(4, 1)\n",
+      "-- batch end --\n",
+      "Iterated in inner format (for developers):\n",
+      "-- df batch start --\n",
+      "   i  j\n",
+      "0  9  0\n",
+      "1  4  1\n",
+      "-- df batch end --\n",
+      "Type:  <class 'torch.utils.data.datapipes.iter.dataframes.DataChunkDF'>\n",
+      "As string:  [(5, 2),(2, 2)]\n",
+      "Iterated regularly:\n",
+      "-- batch start --\n",
+      "(5, 2)\n",
+      "(2, 2)\n",
+      "-- batch end --\n",
+      "Iterated in inner format (for developers):\n",
+      "-- df batch start --\n",
+      "   i  j\n",
+      "2  5  2\n",
+      "2  2  2\n",
+      "-- df batch end --\n",
+      "Type:  <class 'torch.utils.data.datapipes.iter.dataframes.DataChunkDF'>\n",
+      "As string:  [(8, 2),(7, 1)]\n",
+      "Iterated regularly:\n",
+      "-- batch start --\n",
+      "(8, 2)\n",
+      "(7, 1)\n",
+      "-- batch end --\n",
+      "Iterated in inner format (for developers):\n",
+      "-- df batch start --\n",
+      "   i  j\n",
+      "2  8  2\n",
+      "1  7  1\n",
+      "-- df batch end --\n"
+     ]
+    }
+   ],
+   "source": [
+    "dp = get_dataframes_pipe(dataframe_size = 3)\n",
+    "dp = dp.shuffle()\n",
+    "dp = dp.batch(2)\n",
+    "for i in dp:\n",
+    "    print(\"Type: \", type(i))\n",
+    "    print(\"As string: \", i)\n",
+    "    print(\"Iterated regularly:\")\n",
+    "    print('-- batch start --')\n",
+    "    for item in i:\n",
+    "        print(item)\n",
+    "    print('-- batch end --')\n",
+    "    print(\"Iterated in inner format (for developers):\")\n",
+    "    print('-- df batch start --')\n",
+    "    for item in i.raw_iterator():\n",
+    "        print(item)\n",
+    "    print('-- df batch end --')   "
+   ]
+  },
+  {
+   "source": [
+    "`concat` should work only of DF with same schema, this code should produce an error "
+   ],
+   "cell_type": "markdown",
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# TODO!\n",
+    "# dp0 = get_dataframes_pipe(range = 8, dataframe_size = 4)\n",
+    "# dp = get_dataframes_pipe(range = 6, dataframe_size = 3)\n",
+    "# dp['y'] = dp.i * 100 + dp.j - 2.7\n",
+    "# dp = dp.concat(dp0)\n",
+    "# for i,v in enumerate(dp.raw_iterator()):\n",
+    "#     print(v)"
+   ]
+  },
+  {
+   "source": [
+    "`unbatch` of `list` with DataFrame works similarly to regular unbatch.\n",
+    "Note: DataFrame sizes might change"
+   ],
+   "cell_type": "markdown",
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "error",
+     "ename": "AttributeError",
+     "evalue": "",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-12-fa80e9c68655>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;31m# Here is bug with unbatching which doesn't detect DF type.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mdp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'z'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0my\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;36m100\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mraw_iterator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/dataset/pytorch/torch/utils/data/dataset.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, attribute_name)\u001b[0m\n\u001b[1;32m    222\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mfunction\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    223\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 224\u001b[0;31m             \u001b[0;32mraise\u001b[0m \u001b[0mAttributeError\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    225\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    226\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__reduce_ex__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;31mAttributeError\u001b[0m: "
+     ]
+    }
+   ],
+   "source": [
+    "dp = get_dataframes_pipe(range = 18, dataframe_size = 3)\n",
+    "dp['y'] = dp.i * 100 + dp.j - 2.7\n",
+    "dp = dp.batch(5).batch(3).batch(1).unbatch(unbatch_level = 3)\n",
+    "\n",
+    "# Here is bug with unbatching which doesn't detect DF type.\n",
+    "dp['z'] = dp.y - 100\n",
+    "\n",
+    "for i in dp.raw_iterator():\n",
+    "    print(i)"
+   ]
+  },
+  {
+   "source": [
+    "`map` applied to individual rows, `nesting_level` argument used to penetrate batching"
+   ],
+   "cell_type": "markdown",
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "Iterate over DataFrame batches\n[(1111000, 1111000),(1112000, 1112000),(1113000, 1113000),(1114000, 1111000),(1115000, 1112000)]\n[(1116000, 1113000),(1117000, 1111000),(1118000, 1112000),(1119000, 1113000),(1120000, 1111000)]\nIterate over regular batches\n[(1111000, 0),(1112000, 1),(1113000, 2),(1114000, 0),(1115000, 1)]\n[(1116000, 2),(1117000, 0),(1118000, 1),(1119000, 2),(1120000, 0)]\n"
+     ]
+    }
+   ],
+   "source": [
+    "dp = get_dataframes_pipe(range = 10, dataframe_size = 3)\n",
+    "dp = dp.map(lambda x: x + 1111)\n",
+    "dp = dp.batch(5).map(lambda x: x * 1000, nesting_level = 1)\n",
+    "print(\"Iterate over DataFrame batches\")\n",
+    "for i in dp:\n",
+    "    print(i)\n",
+    "\n",
+    "# Similarly works on row level for classic DataPipe elements\n",
+    "dp = get_regular_pipe(range = 10)\n",
+    "dp = dp.map(lambda x: (x[0] + 1111, x[1]))\n",
+    "dp = dp.batch(5).map(lambda x: (x[0] * 1000, x[1]), nesting_level = 1)\n",
+    "print(\"Iterate over regular batches\")\n",
+    "for i in dp:\n",
+    "    print(i)\n",
+    "\n"
+   ]
+  },
+  {
+   "source": [
+    "`filter` applied to individual rows, `nesting_level` argument used to penetrate batching"
+   ],
+   "cell_type": "markdown",
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "Iterate over DataFrame batches\n[(6, 0),(7, 1),(8, 2),(9, 0),(10, 1)]\n[(11, 2),(12, 0)]\nIterate over regular batches\n[(6, 0),(7, 1),(8, 2),(9, 0),(10, 1)]\n[(11, 2),(12, 0)]\n"
+     ]
+    }
+   ],
+   "source": [
+    "dp = get_dataframes_pipe(range = 30, dataframe_size = 3)\n",
+    "dp = dp.filter(lambda x: x.i > 5)\n",
+    "dp = dp.batch(5).filter(lambda x: x.i < 13, nesting_level = 1)\n",
+    "print(\"Iterate over DataFrame batches\")\n",
+    "for i in dp:\n",
+    "    print(i)\n",
+    "\n",
+    "# Similarly works on row level for classic DataPipe elements\n",
+    "dp = get_regular_pipe(range = 30)\n",
+    "dp = dp.filter(lambda x: x[0] > 5)\n",
+    "dp = dp.batch(5).filter(lambda x: x[0] < 13, nesting_level = 1)\n",
+    "print(\"Iterate over regular batches\")\n",
+    "for i in dp:\n",
+    "    print(i)"
+   ]
+  }
+ ]
+}
\ No newline at end of file
diff --git a/torch/utils/data/datapipes/__init__.py b/torch/utils/data/datapipes/__init__.py
index 0ada5a4..f19389e 100644
--- a/torch/utils/data/datapipes/__init__.py
+++ b/torch/utils/data/datapipes/__init__.py
@@ -1,2 +1,3 @@
 from . import iter
 from . import map
+from . import dataframe
diff --git a/torch/utils/data/datapipes/dataframe/__init__.py b/torch/utils/data/datapipes/dataframe/__init__.py
new file mode 100644
index 0000000..494bf32
--- /dev/null
+++ b/torch/utils/data/datapipes/dataframe/__init__.py
@@ -0,0 +1,11 @@
+from torch.utils.data.datapipes.dataframe.dataframes import (
+    DFIterDataPipe,
+)
+from torch.utils.data.datapipes.dataframe.datapipes import (
+    DataFramesAsTuplesPipe,
+)
+
+__all__ = ['DFIterDataPipe', 'DataFramesAsTuplesPipe']
+
+# Please keep this list sorted
+assert __all__ == sorted(__all__)
diff --git a/torch/utils/data/datapipes/dataframe/dataframes.py b/torch/utils/data/datapipes/dataframe/dataframes.py
new file mode 100644
index 0000000..2eab45f
--- /dev/null
+++ b/torch/utils/data/datapipes/dataframe/dataframes.py
@@ -0,0 +1,292 @@
+from typing import Any, Dict, List
+
+from torch.utils.data import (
+    DFIterDataPipe,
+    IterDataPipe,
+    functional_datapipe,
+)
+
+from torch.utils.data.datapipes.dataframe.structures import DataChunkDF
+
+# TODO(VitalyFedyunin): Add error when two different traces get combined
+class DataFrameTracedOps(DFIterDataPipe):
+    def __init__(self, source_datapipe, output_var):
+        self.source_datapipe = source_datapipe
+        self.output_var = output_var
+
+    def __iter__(self):
+        for item in self.source_datapipe:
+            yield self.output_var.calculate_me(item)
+
+
+#  TODO(VitalyFedyunin): Extract this list from the DFIterDataPipe registered functions
+DATAPIPES_OPS = ['_dataframes_as_tuples', 'groupby', '_dataframes_filter', 'map', 'to_datapipe',
+                 'shuffle', 'concat', 'batch', '_dataframes_per_row', '_dataframes_concat', '_dataframes_shuffle']
+
+
+class Capture(object):
+    # All operations are shared across the entire InitialCapture; need to figure out what happens if we join two captures
+    ctx: Dict[str, List[Any]]
+
+    def __init__(self):
+        self.ctx = {'operations': [], 'variables': []}
+
+    def __str__(self):
+        return self.ops_str()
+
+    def ops_str(self):
+        res = ""
+        for op in self.ctx['operations']:
+            if len(res) > 0:
+                res += "\n"
+            res += str(op)
+        return res
+
+    def __getattr__(self, attrname):
+        if attrname == 'kwarg':
+            raise Exception('no kwargs!')
+        if attrname in DATAPIPES_OPS:
+            return (self.as_datapipe()).__getattr__(attrname)
+        return CaptureGetAttr(self, attrname, ctx=self.ctx)
+
+    def __getitem__(self, key):
+        return CaptureGetItem(self, key, ctx=self.ctx)
+
+    def __setitem__(self, key, value):
+        self.ctx['operations'].append(
+            CaptureSetItem(self, key, value, ctx=self.ctx))
+
+    def __add__(self, add_val):
+        res = CaptureAdd(self, add_val, ctx=self.ctx)
+        var = CaptureVariable(res, ctx=self.ctx)
+        self.ctx['operations'].append(
+            CaptureVariableAssign(variable=var, value=res, ctx=self.ctx))
+        return var
+
+    def __sub__(self, add_val):
+        res = CaptureSub(self, add_val, ctx=self.ctx)
+        var = CaptureVariable(res, ctx=self.ctx)
+        self.ctx['operations'].append(
+            CaptureVariableAssign(variable=var, value=res, ctx=self.ctx))
+        return var
+
+    def __mul__(self, add_val):
+        res = CaptureMul(self, add_val, ctx=self.ctx)
+        var = CaptureVariable(res, ctx=self.ctx)
+        t = CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
+        self.ctx['operations'].append(t)
+        return var
+
+    def as_datapipe(self):
+        return DataFrameTracedOps(
+            self.ctx['variables'][0].source_datapipe, self)
+
+    def raw_iterator(self):
+        return self.as_datapipe().__iter__()
+
+    def __iter__(self):
+        return iter(self._dataframes_as_tuples())
+
+    def batch(self, batch_size=10):
+        dp = self._dataframes_per_row()._dataframes_concat(batch_size)
+        dp = dp.as_datapipe().batch(1, wrapper_class=DataChunkDF)
+        dp._dp_contains_dataframe = True
+        return dp
+
+    def groupby(self,
+                group_key_fn,
+                *,
+                buffer_size=10000,
+                group_size=None,
+                unbatch_level=0,
+                guaranteed_group_size=None,
+                drop_remaining=False):
+        if unbatch_level != 0:
+            dp = self.unbatch(unbatch_level)._dataframes_per_row()
+        else:
+            dp = self._dataframes_per_row()
+        dp = dp.as_datapipe().groupby(group_key_fn, buffer_size=buffer_size, group_size=group_size,
+                                      guaranteed_group_size=guaranteed_group_size, drop_remaining=drop_remaining)
+        return dp
+
+    def shuffle(self, *args, **kwargs):
+        return self._dataframes_shuffle(*args, **kwargs)
+
+    def filter(self, *args, **kwargs):
+        return self._dataframes_filter(*args, **kwargs)
+
+
+class CaptureF(Capture):
+    def __init__(self, ctx=None, **kwargs):
+        if ctx is None:
+            ctx = {'operations': [], 'variables': []}
+        self.ctx = ctx
+        self.kwargs = kwargs
+
+
+class CaptureCall(CaptureF):
+    def __str__(self):
+        return "{variable}({args},{kwargs})".format(**self.kwargs)
+
+    def execute(self):
+        return (get_val(self.kwargs['variable']))(*self.kwargs['args'], **self.kwargs['kwargs'])
+
+
+class CaptureVariableAssign(CaptureF):
+    def __str__(self):
+        return "{variable} = {value}".format(**self.kwargs)
+
+    def execute(self):
+        self.kwargs['variable'].calculated_value = self.kwargs['value'].execute()
+
+
+class CaptureVariable(Capture):
+    value = None
+    name = None
+    calculated_value = None
+    names_idx = 0
+
+    def __init__(self, value, ctx):
+        self.ctx = ctx
+        self.value = value
+        self.name = 'var_%s' % CaptureVariable.names_idx
+        CaptureVariable.names_idx += 1
+        self.ctx['variables'].append(self)
+
+    def __str__(self):
+        return self.name
+
+    def execute(self):
+        return self.calculated_value
+
+    def calculate_me(self, dataframe):
+        self.ctx['variables'][0].calculated_value = dataframe
+        for op in self.ctx['operations']:
+            op.execute()
+        return self.calculated_value
+
+
+class CaptureInitial(CaptureVariable):
+
+    def __init__(self):
+        new_ctx: Dict[str, List[Any]] = {'operations': [], 'variables': []}
+        super().__init__(None, new_ctx)
+        self.name = 'input_%s' % self.name
+
+
+class CaptureGetItem(Capture):
+    left : Capture
+    key : Any
+
+    def __init__(self, left, key, ctx):
+        self.ctx = ctx
+        self.left = left
+        self.key = key
+
+    def __str__(self):
+        return "%s[%s]" % (self.left, get_val(self.key))
+
+    def execute(self):
+        return (self.left.execute())[self.key]
+
+
+class CaptureSetItem(Capture):
+    left : Capture
+    key : Any
+    value : Capture
+
+    def __init__(self, left, key, value, ctx):
+        self.ctx = ctx
+        self.left = left
+        self.key = key
+        self.value = value
+
+    def __str__(self):
+        return "%s[%s] = %s" % (self.left, get_val(self.key), self.value)
+
+    def execute(self):
+        (self.left.execute())[
+            self.key] = self.value.execute()
+
+
+class CaptureAdd(Capture):
+    left = None
+    right = None
+
+    def __init__(self, left, right, ctx):
+        self.ctx = ctx
+        self.left = left
+        self.right = right
+
+    def __str__(self):
+        return "%s + %s" % (self.left, self.right)
+
+    def execute(self):
+        return get_val(self.left) + get_val(self.right)
+
+
+class CaptureMul(Capture):
+    left = None
+    right = None
+
+    def __init__(self, left, right, ctx):
+        self.ctx = ctx
+        self.left = left
+        self.right = right
+
+    def __str__(self):
+        return "%s * %s" % (self.left, self.right)
+
+    def execute(self):
+        return get_val(self.left) * get_val(self.right)
+
+
+class CaptureSub(Capture):
+    left = None
+    right = None
+
+    def __init__(self, left, right, ctx):
+        self.ctx = ctx
+        self.left = left
+        self.right = right
+
+    def __str__(self):
+        return "%s - %s" % (self.left, self.right)
+
+    def execute(self):
+        return get_val(self.left) - get_val(self.right)
+
+
+class CaptureGetAttr(Capture):
+    source = None
+    name: str
+
+    def __init__(self, src, name, ctx):
+        self.ctx = ctx
+        self.src = src
+        self.name = name
+
+    def __str__(self):
+        return "%s.%s" % (self.src, self.name)
+
+    def execute(self):
+        val = get_val(self.src)
+        return getattr(val, self.name)
+
+
+def get_val(capture):
+    if isinstance(capture, Capture):
+        return capture.execute()
+    elif isinstance(capture, str):
+        return '"%s"' % capture
+    else:
+        return capture
+
+
+@functional_datapipe('trace_as_dataframe')
+class DataFrameTracer(CaptureInitial, IterDataPipe):
+    source_datapipe = None
+
+    def __init__(self, source_datapipe):
+        super().__init__()
+        self.source_datapipe = source_datapipe
diff --git a/torch/utils/data/datapipes/dataframe/datapipes.py b/torch/utils/data/datapipes/dataframe/datapipes.py
new file mode 100644
index 0000000..f76189a
--- /dev/null
+++ b/torch/utils/data/datapipes/dataframe/datapipes.py
@@ -0,0 +1,137 @@
+import random
+
+from torch.utils.data import (
+    DFIterDataPipe,
+    IterDataPipe,
+    functional_datapipe,
+)
+
+try:
+    import pandas  # type: ignore[import]
+    # pandas is used only for prototyping and will shortly be replaced with TorchArrow
+    WITH_PANDAS = True
+except ImportError:
+    WITH_PANDAS = False
+
+
+@functional_datapipe('_dataframes_as_tuples')
+class DataFramesAsTuplesPipe(IterDataPipe):
+    def __init__(self, source_datapipe):
+        self.source_datapipe = source_datapipe
+
+    def __iter__(self):
+        for df in self.source_datapipe:
+            for record in df.to_records(index=False):
+                yield record
+
+
+@functional_datapipe('_dataframes_per_row', enable_df_api_tracing=True)
+class PerRowDataFramesPipe(DFIterDataPipe):
+    def __init__(self, source_datapipe):
+        self.source_datapipe = source_datapipe
+
+    def __iter__(self):
+        for df in self.source_datapipe:
+            for i in range(len(df.index)):
+                yield df[i:i + 1]
+
+
+@functional_datapipe('_dataframes_concat', enable_df_api_tracing=True)
+class ConcatDataFramesPipe(DFIterDataPipe):
+    def __init__(self, source_datapipe, batch=3):
+        self.source_datapipe = source_datapipe
+        self.batch = batch
+        if not WITH_PANDAS:
+            raise Exception('DataFrames prototype requires pandas to function')
+
+    def __iter__(self):
+        buffer = []
+        for df in self.source_datapipe:
+            buffer.append(df)
+            if len(buffer) == self.batch:
+                yield pandas.concat(buffer)
+                buffer = []
+        if len(buffer):
+            yield pandas.concat(buffer)
+
+
+@functional_datapipe('_dataframes_shuffle', enable_df_api_tracing=True)
+class ShuffleDataFramesPipe(DFIterDataPipe):
+    def __init__(self, source_datapipe):
+        self.source_datapipe = source_datapipe
+        if not WITH_PANDAS:
+            raise Exception('DataFrames prototype requires pandas to function')
+
+    def __iter__(self):
+        size = None
+        all_buffer = []
+        for df in self.source_datapipe:
+            if size is None:
+                size = len(df.index)
+            for i in range(len(df.index)):
+                all_buffer.append(df[i:i + 1])
+        random.shuffle(all_buffer)
+        buffer = []
+        for df in all_buffer:
+            buffer.append(df)
+            if len(buffer) == size:
+                yield pandas.concat(buffer)
+                buffer = []
+        if len(buffer):
+            yield pandas.concat(buffer)
+
+
+@functional_datapipe('_dataframes_filter', enable_df_api_tracing=True)
+class FilterDataFramesPipe(DFIterDataPipe):
+    def __init__(self, source_datapipe, filter_fn):
+        self.source_datapipe = source_datapipe
+        self.filter_fn = filter_fn
+        if not WITH_PANDAS:
+            raise Exception('DataFrames prototype requires pandas to function')
+
+    def __iter__(self):
+        size = None
+        all_buffer = []
+        filter_res = []
+        for df in self.source_datapipe:
+            if size is None:
+                size = len(df.index)
+            for i in range(len(df.index)):
+                all_buffer.append(df[i:i + 1])
+                filter_res.append(self.filter_fn(df.iloc[i]))
+
+        buffer = []
+        for df, res in zip(all_buffer, filter_res):
+            if res:
+                buffer.append(df)
+                if len(buffer) == size:
+                    yield pandas.concat(buffer)
+                    buffer = []
+        if len(buffer):
+            yield pandas.concat(buffer)
+
+
+@functional_datapipe('_to_dataframes_pipe', enable_df_api_tracing=True)
+class ExampleAggregateAsDataFrames(DFIterDataPipe):
+    def __init__(self, source_datapipe, dataframe_size=10, columns=None):
+        self.source_datapipe = source_datapipe
+        self.columns = columns
+        self.dataframe_size = dataframe_size
+        if not WITH_PANDAS:
+            raise Exception('DataFrames prototype requires pandas to function')
+
+    def _as_list(self, item):
+        try:
+            return list(item)
+        except Exception:  # TODO(VitalyFedyunin): Replace with better iterable exception
+            return [item]
+
+    def __iter__(self):
+        aggregate = []
+        for item in self.source_datapipe:
+            aggregate.append(self._as_list(item))
+            if len(aggregate) == self.dataframe_size:
+                yield pandas.DataFrame(aggregate, columns=self.columns)
+                aggregate = []
+        if len(aggregate) > 0:
+            yield pandas.DataFrame(aggregate, columns=self.columns)
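
For context, registering a DataFrames-aware pipe with the new `enable_df_api_tracing` flag looks roughly like the sketch below (the pipe name and class body here are hypothetical; the real registrations are the decorators in this file):

    from torch.utils.data import DFIterDataPipe, functional_datapipe

    # enable_df_api_tracing=True keeps DataFrames-API tracing active on the
    # DataPipe returned by the functional form.
    @functional_datapipe('_dataframes_identity', enable_df_api_tracing=True)
    class IdentityDataFramesPipe(DFIterDataPipe):
        def __init__(self, source_datapipe):
            self.source_datapipe = source_datapipe

        def __iter__(self):
            yield from self.source_datapipe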
diff --git a/torch/utils/data/datapipes/dataframe/structures.py b/torch/utils/data/datapipes/dataframe/structures.py
new file mode 100644
index 0000000..c822f89
--- /dev/null
+++ b/torch/utils/data/datapipes/dataframe/structures.py
@@ -0,0 +1,20 @@
+from torch.utils.data import (
+    DataChunk,
+)
+
+class DataChunkDF(DataChunk):
+    """
+        DataChunkDF iterates over the individual items inside DataFrame containers;
+        to access the DataFrames themselves, use `raw_iterator`.
+    """
+
+    def __iter__(self):
+        for df in self.items:
+            for record in df.to_records(index=False):
+                yield record
+
+    def __len__(self):
+        total_len = 0
+        for df in self.items:
+            total_len += len(df)
+        return total_len
diff --git a/torch/utils/data/datapipes/iter/__init__.py b/torch/utils/data/datapipes/iter/__init__.py
index 26d715d..0b96cfd 100644
--- a/torch/utils/data/datapipes/iter/__init__.py
+++ b/torch/utils/data/datapipes/iter/__init__.py
@@ -54,6 +54,7 @@
            'BucketBatcher',
            'Collator',
            'Concater',
+           'DFIterDataPipe',
            'Demultiplexer',
            'FileLister',
            'FileLoader',
diff --git a/torch/utils/data/datapipes/iter/grouping.py b/torch/utils/data/datapipes/iter/grouping.py
index d90ad08..f848681 100644
--- a/torch/utils/data/datapipes/iter/grouping.py
+++ b/torch/utils/data/datapipes/iter/grouping.py
@@ -60,6 +60,7 @@
                  batch_size: int,
                  drop_last: bool = False,
                  unbatch_level: int = 0,
+                 wrapper_class=DataChunk,
                  ) -> None:
         assert batch_size > 0, "Batch size is required to be larger than 0!"
         super().__init__()
@@ -71,7 +72,7 @@
         self.batch_size = batch_size
         self.drop_last = drop_last
         self.length = None
-        self.wrapper_class = DataChunk
+        self.wrapper_class = wrapper_class
 
     def __iter__(self) -> Iterator[DataChunk]:
         batch: List = []
diff --git a/torch/utils/data/datapipes/iter/selecting.py b/torch/utils/data/datapipes/iter/selecting.py
index f1889e5..2f0a158 100644
--- a/torch/utils/data/datapipes/iter/selecting.py
+++ b/torch/utils/data/datapipes/iter/selecting.py
@@ -1,6 +1,14 @@
 import warnings
-from torch.utils.data import IterDataPipe, functional_datapipe, DataChunk
-from typing import Callable, TypeVar, Iterator, Optional, Tuple, Dict
+from typing import Callable, Dict, Iterator, Optional, Tuple, TypeVar
+
+from torch.utils.data import DataChunk, IterDataPipe, functional_datapipe
+
+try:
+    import pandas  # type: ignore[import]
+    # pandas is used only for prototyping and will shortly be replaced with TorchArrow
+    WITH_PANDAS = True
+except ImportError:
+    WITH_PANDAS = False
 
 T_co = TypeVar('T_co', covariant=True)
 
@@ -91,12 +99,27 @@
 
     def _returnIfTrue(self, data):
         condition = self.filter_fn(data, *self.args, **self.kwargs)
+        if WITH_PANDAS:
+            if isinstance(condition, pandas.core.series.Series):
+                # We are operating on a DataFrames filter here
+                result = []
+                for idx, mask in enumerate(condition):
+                    if mask:
+                        result.append(data[idx:idx + 1])
+                if len(result):
+                    return pandas.concat(result)
+                else:
+                    return None
+
         if not isinstance(condition, bool):
-            raise ValueError("Boolean output is required for `filter_fn` of FilterIterDataPipe")
+            raise ValueError("Boolean output is required for `filter_fn` of FilterIterDataPipe, got", type(condition))
         if condition:
             return data
 
     def _isNonEmpty(self, data):
+        if WITH_PANDAS:
+            if isinstance(data, pandas.core.frame.DataFrame):
+                return True
         r = data is not None and \
             not (isinstance(data, list) and len(data) == 0 and self.drop_empty_batches)
         return r
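
As a quick illustration of the Series handling added above, here is a standalone sketch of what `_returnIfTrue` now effectively does when `filter_fn` is applied to a whole DataFrame (illustrative code, not part of the patch; assumes pandas is available):

    import pandas

    df = pandas.DataFrame({'i': [6, 7, 8], 'j': [0, 1, 2]})
    condition = (lambda x: x.i > 6)(df)  # filter_fn on a DataFrame yields a boolean Series
    kept = [df[idx:idx + 1] for idx, mask in enumerate(condition) if mask]
    result = pandas.concat(kept) if kept else None  # keeps the rows with i == 7 and i == 8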
diff --git a/torch/utils/data/dataset.py b/torch/utils/data/dataset.py
index 50488d1..c968fdf 100644
--- a/torch/utils/data/dataset.py
+++ b/torch/utils/data/dataset.py
@@ -24,6 +24,11 @@
 T_co = TypeVar('T_co', covariant=True)
 T = TypeVar('T')
 
+UNTRACABLE_DATAFRAME_PIPES = ['batch',  # As it returns DataChunks
+                              'groupby',   # As it returns DataChunks
+                              '_dataframes_as_tuples',  # As it unpacks DF
+                              'trace_as_dataframe',  # As it used to mark DF for tracing
+                              ]
 
 class DataChunk(list, Generic[T]):
     def __init__(self, items):
@@ -82,13 +87,20 @@
         cls.functions[function_name] = function
 
     @classmethod
-    def register_datapipe_as_function(cls, function_name, cls_to_register):
+    def register_datapipe_as_function(cls, function_name, cls_to_register, enable_df_api_tracing=False):
         if function_name in cls.functions:
             raise Exception("Unable to add DataPipe function name {} as it is already taken".format(function_name))
 
-        def class_function(cls, source_dp, *args, **kwargs):
-            return cls(source_dp, *args, **kwargs)
-        function = functools.partial(class_function, cls_to_register)
+        def class_function(cls, enable_df_api_tracing, source_dp, *args, **kwargs):
+            result_pipe = cls(source_dp, *args, **kwargs)
+            if isinstance(result_pipe, Dataset):
+                if enable_df_api_tracing or isinstance(source_dp, DFIterDataPipe):
+                    if function_name not in UNTRACABLE_DATAFRAME_PIPES:
+                        result_pipe = result_pipe.trace_as_dataframe()
+
+            return result_pipe
+
+        function = functools.partial(class_function, cls_to_register, enable_df_api_tracing)
         cls.functions[function_name] = function
 
 
@@ -227,6 +239,9 @@
             raise Exception("Attempt to override existing reduce_ex_hook")
         IterableDataset.reduce_ex_hook = hook_fn
 
+class DFIterDataPipe(IterableDataset):
+    def _is_dfpipe(self):
+        return True
 
 class TensorDataset(Dataset[Tuple[Tensor, ...]]):
     r"""Dataset wrapping tensors.