[WIP] Example of DataPipes and DataFrames integration (#60840)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/60840
Test Plan: Imported from OSS
Reviewed By: wenleix, ejguan
Differential Revision: D29461080
Pulled By: VitalyFedyunin
fbshipit-source-id: 4909394dcd39e97ee49b699fda542b311b7e0d82
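
For context, here is a condensed sketch of the API this PR demonstrates, adapted from the notebook added in torch/utils/data/dataframes_pipes.ipynb (ExampleIterPipe is the toy pipe defined there; pandas must be installed):

    from torch.utils.data import IterDataPipe

    class ExampleIterPipe(IterDataPipe):
        def __init__(self, range=20):
            self.range = range
        def __iter__(self):
            for i in range(self.range):
                yield i

    # Aggregate tuples into pandas DataFrames of up to 3 rows, then trace DF operations
    dp = ExampleIterPipe(range=10).map(lambda i: (i, i % 3)) \
        ._to_dataframes_pipe(columns=['i', 'j'], dataframe_size=3)
    dp['y'] = dp.i * 100 + dp.j - 2.7   # captured, executed lazily per DataFrame
    for row in dp:                      # iterating still yields individual rows
        print(row)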
diff --git a/test/test_datapipe.py b/test/test_datapipe.py
index c82f81b..4d5bd7c 100644
--- a/test/test_datapipe.py
+++ b/test/test_datapipe.py
@@ -65,6 +65,13 @@
HAS_DILL = False
skipIfNoDill = skipIf(not HAS_DILL, "no dill")
+try:
+ import pandas # type: ignore[import] # noqa: F401 F403
+ HAS_PANDAS = True
+except ImportError:
+ HAS_PANDAS = False
+skipIfNoDataFrames = skipIf(not HAS_PANDAS, "no dataframes (pandas)")
+
T_co = TypeVar("T_co", covariant=True)
@@ -393,6 +400,64 @@
self.assertEqual(source_numbers, list(n))
+class TestDataFramesPipes(TestCase):
+ """
+ Most of these tests will fail if pandas is installed but dill is not available.
+ Need to rework them to avoid multiple skips.
+ """
+ def _get_datapipe(self, range=10, dataframe_size=7):
+ return NumbersDataset(range) \
+ .map(lambda i: (i, i % 3))
+
+ def _get_dataframes_pipe(self, range=10, dataframe_size=7):
+ return NumbersDataset(range) \
+ .map(lambda i: (i, i % 3)) \
+ ._to_dataframes_pipe(
+ columns=['i', 'j'],
+ dataframe_size=dataframe_size)
+
+ @skipIfNoDataFrames
+ @skipIfNoDill # TODO(VitalyFedyunin): Decouple tests from dill by avoiding lambdas in map
+ def test_capture(self):
+ dp_numbers = self._get_datapipe().map(lambda x: (x[0], x[1], x[1] + 3 * x[0]))
+ df_numbers = self._get_dataframes_pipe()
+ df_numbers['k'] = df_numbers['j'] + df_numbers.i * 3
+ self.assertEqual(list(dp_numbers), list(df_numbers))
+
+ @skipIfNoDataFrames
+ @skipIfNoDill
+ def test_shuffle(self):
+ # With non-zero (but extremely low) probability (i.e. when the shuffle happens to change nothing),
+ # this test fails, so feel free to restart it
+ df_numbers = self._get_dataframes_pipe(range=1000).shuffle()
+ dp_numbers = self._get_datapipe(range=1000)
+ df_result = [tuple(item) for item in df_numbers]
+ self.assertNotEqual(list(dp_numbers), df_result)
+ self.assertEqual(list(dp_numbers), sorted(df_result))
+
+ @skipIfNoDataFrames
+ @skipIfNoDill
+ def test_batch(self):
+ df_numbers = self._get_dataframes_pipe(range=100).batch(8)
+ df_numbers_list = list(df_numbers)
+ last_batch = df_numbers_list[-1]
+ self.assertEqual(4, len(last_batch))
+ unpacked_batch = [tuple(row) for row in last_batch]
+ self.assertEqual([(96, 0), (97, 1), (98, 2), (99, 0)], unpacked_batch)
+
+ @skipIfNoDataFrames
+ @skipIfNoDill
+ def test_unbatch(self):
+ df_numbers = self._get_dataframes_pipe(range=100).batch(8).batch(3)
+ dp_numbers = self._get_datapipe(range=100)
+ self.assertEqual(list(dp_numbers), list(df_numbers.unbatch(2)))
+
+ @skipIfNoDataFrames
+ @skipIfNoDill
+ def test_filter(self):
+ df_numbers = self._get_dataframes_pipe(range=10).filter(lambda x: x.i > 5)
+ self.assertEqual([(6, 0), (7, 1), (8, 2), (9, 0)], list(df_numbers))
+
class FileLoggerSimpleHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
def __init__(self, *args, logfile=None, **kwargs):
self.__loggerHandle = None
diff --git a/tools/linter/clang_tidy/run.py b/tools/linter/clang_tidy/run.py
index f84d225..9e71333 100644
--- a/tools/linter/clang_tidy/run.py
+++ b/tools/linter/clang_tidy/run.py
@@ -421,7 +421,7 @@
i += 1
ranges[-1][1] = added_line_nos[-1]
- files[file.path].append(*ranges)
+ files[file.path] += ranges
return dict(files)
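
The one-line change above fixes a latent bug: list.append(*ranges) unpacks the ranges into separate positional arguments, so it raises a TypeError whenever more than one range was collected. A minimal standalone illustration (not part of the diff):

    files = {'a.cpp': []}
    ranges = [[1, 3], [10, 12]]
    # files['a.cpp'].append(*ranges)   # TypeError: append() takes exactly one argument (2 given)
    files['a.cpp'] += ranges           # extends in place: [[1, 3], [10, 12]]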
diff --git a/torch/utils/data/__init__.py b/torch/utils/data/__init__.py
index ac0c763..167a860 100644
--- a/torch/utils/data/__init__.py
+++ b/torch/utils/data/__init__.py
@@ -14,6 +14,7 @@
DataChunk,
Dataset,
Dataset as MapDataPipe,
+ DFIterDataPipe,
IterableDataset,
IterableDataset as IterDataPipe,
Subset,
diff --git a/torch/utils/data/_decorator.py b/torch/utils/data/_decorator.py
index 6e64fd5..53dce59 100644
--- a/torch/utils/data/_decorator.py
+++ b/torch/utils/data/_decorator.py
@@ -11,8 +11,14 @@
class functional_datapipe(object):
name: str
- def __init__(self, name: str) -> None:
+ def __init__(self, name: str, enable_df_api_tracing=False) -> None:
+ """
+ Args:
+ enable_df_api_tracing - if set, the DataPipe returned by the functional call
+ will accept the DataFrames API in tracing mode.
+ """
self.name = name
+ self.enable_df_api_tracing = enable_df_api_tracing
def __call__(self, cls):
if issubclass(cls, IterDataPipe):
@@ -25,9 +31,9 @@
not (hasattr(cls, '__self__') and
isinstance(cls.__self__, non_deterministic)):
raise TypeError('`functional_datapipe` can only decorate IterDataPipe')
- IterDataPipe.register_datapipe_as_function(self.name, cls)
+ IterDataPipe.register_datapipe_as_function(self.name, cls, enable_df_api_tracing=self.enable_df_api_tracing)
elif issubclass(cls, MapDataPipe):
- MapDataPipe.register_datapipe_as_function(self.name, cls)
+ MapDataPipe.register_datapipe_as_function(self.name, cls, enable_df_api_tracing=self.enable_df_api_tracing)
return cls
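
For illustration, a minimal sketch of how a DataPipe is registered with the new flag (this mirrors the registrations added in datapipes.py below; the name '_my_df_op' and the class are hypothetical):

    from torch.utils.data import DFIterDataPipe, functional_datapipe

    @functional_datapipe('_my_df_op', enable_df_api_tracing=True)
    class MyDataFramesPipe(DFIterDataPipe):
        def __init__(self, source_datapipe):
            self.source_datapipe = source_datapipe

        def __iter__(self):
            for df in self.source_datapipe:
                yield df  # pass DataFrames through unchanged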
diff --git a/torch/utils/data/dataframes_pipes.ipynb b/torch/utils/data/dataframes_pipes.ipynb
new file mode 100644
index 0000000..74e36ce
--- /dev/null
+++ b/torch/utils/data/dataframes_pipes.ipynb
@@ -0,0 +1,523 @@
+{
+ "metadata": {
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.10"
+ },
+ "orig_nbformat": 2,
+ "kernelspec": {
+ "name": "python3610jvsc74a57bd0eb5e09632d6ea1cbf3eb9da7e37b7cf581db5ed13074b21cc44e159dc62acdab",
+ "display_name": "Python 3.6.10 64-bit ('dataloader': conda)"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2,
+ "cells": [
+ {
+ "source": [
+ "## \\[RFC\\] How DataFrames (DF) and DataPipes (DP) work together"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from importlib import reload\n",
+ "import torch\n",
+ "reload(torch)\n",
+ "from torch.utils.data import IterDataPipe"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Example IterDataPipe\n",
+ "class ExampleIterPipe(IterDataPipe):\n",
+ " def __init__(self, range = 20):\n",
+ " self.range = range\n",
+ " def __iter__(self):\n",
+ " for i in range(self.range):\n",
+ " yield i\n",
+ "\n",
+ "def get_dataframes_pipe(range = 10, dataframe_size = 7):\n",
+ " return ExampleIterPipe(range = range).map(lambda i: (i, i % 3))._to_dataframes_pipe(columns = ['i','j'], dataframe_size = dataframe_size)\n",
+ "\n",
+ "def get_regular_pipe(range = 10):\n",
+ " return ExampleIterPipe(range = range).map(lambda i: (i, i % 3))\n"
+ ]
+ },
+ {
+ "source": [
+    "Regardless of how the DF is composed internally, iterating over a DF Pipe yields individual rows to the user. This is similar to a regular DataPipe."
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "DataFrames Pipe\n(0, 0)\n(1, 1)\n(2, 2)\n(3, 0)\n(4, 1)\n(5, 2)\n(6, 0)\n(7, 1)\n(8, 2)\n(9, 0)\nRegular DataPipe\n(0, 0)\n(1, 1)\n(2, 2)\n(3, 0)\n(4, 1)\n(5, 2)\n(6, 0)\n(7, 1)\n(8, 2)\n(9, 0)\n"
+ ]
+ }
+ ],
+ "source": [
+ "print('DataFrames Pipe')\n",
+ "dp = get_dataframes_pipe()\n",
+ "for i in dp:\n",
+ " print(i)\n",
+ "\n",
+ "print('Regular DataPipe')\n",
+ "dp = get_regular_pipe()\n",
+ "for i in dp:\n",
+ " print(i)"
+ ]
+ },
+ {
+ "source": [
+    "You can iterate over the raw DFs using `raw_iterator`"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ " i j\n0 0 0\n1 1 1\n2 2 2\n3 3 0\n4 4 1\n5 5 2\n6 6 0\n i j\n0 7 1\n1 8 2\n2 9 0\n"
+ ]
+ }
+ ],
+ "source": [
+ "dp = get_dataframes_pipe()\n",
+ "for i in dp.raw_iterator():\n",
+ " print(i)"
+ ]
+ },
+ {
+ "source": [
+    "Operations over the DF Pipe are captured"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "var_3 = input_var_2.i * 100\nvar_4 = var_3 + input_var_2.j\nvar_5 = var_4 - 2.7\ninput_var_2[\"y\"] = var_5\n"
+ ]
+ }
+ ],
+ "source": [
+ "dp = get_dataframes_pipe(dataframe_size = 3)\n",
+ "dp['y'] = dp.i * 100 + dp.j - 2.7\n",
+ "print(dp.ops_str())\n"
+ ]
+ },
+ {
+ "source": [
+    "Captured operations are executed on `__next__` calls of the constructed DataPipe"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ " i j y\n0 0 0 -2.7\n1 1 1 98.3\n2 2 2 199.3\n i j y\n0 3 0 297.3\n1 4 1 398.3\n2 5 2 499.3\n i j y\n0 6 0 597.3\n1 7 1 698.3\n2 8 2 799.3\n i j y\n0 9 0 897.3\n"
+ ]
+ }
+ ],
+ "source": [
+ "dp = get_dataframes_pipe(dataframe_size = 3)\n",
+ "dp['y'] = dp.i * 100 + dp.j - 2.7\n",
+ "for i in dp.raw_iterator():\n",
+ " print(i)"
+ ]
+ },
+ {
+ "source": [
+    "`shuffle` of a DataFramePipe affects rows individually"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Raw DataFrames iterator\n i j\n2 8 2\n2 2 2\n2 5 2\n i j\n1 4 1\n1 1 1\n0 3 0\n i j\n1 7 1\n0 9 0\n0 6 0\n i j\n0 0 0\nRegular DataFrames iterator\n(1, 1)\n(5, 2)\n(8, 2)\n(9, 0)\n(7, 1)\n(6, 0)\n(3, 0)\n(4, 1)\n(0, 0)\n(2, 2)\nRegular iterator\n(5, 2)\n(6, 0)\n(0, 0)\n(9, 0)\n(3, 0)\n(1, 1)\n(2, 2)\n(8, 2)\n(4, 1)\n(7, 1)\n"
+ ]
+ }
+ ],
+ "source": [
+ "dp = get_dataframes_pipe(dataframe_size = 3)\n",
+ "dp = dp.shuffle()\n",
+ "print('Raw DataFrames iterator')\n",
+ "for i in dp.raw_iterator():\n",
+ " print(i)\n",
+ "\n",
+ "print('Regular DataFrames iterator')\n",
+ "for i in dp:\n",
+ " print(i)\n",
+ "\n",
+ "\n",
+ "# this is similar to shuffle of regular DataPipe\n",
+ "dp = get_regular_pipe()\n",
+ "dp = dp.shuffle()\n",
+ "print('Regular iterator')\n",
+ "for i in dp:\n",
+ " print(i)"
+ ]
+ },
+ {
+ "source": [
+ "You can continue mixing DF and DP operations"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ " i j y\n0 -17 -17 -197000.0\n1 -13 -16 3813000.0\n0 -11 -17 5803000.0\n i j y\n2 -12 -15 4823000.0\n1 -10 -16 6813000.0\n1 -16 -16 813000.0\n i j y\n0 -8 -17 8803000.0\n2 -9 -15 7823000.0\n0 -14 -17 2803000.0\n i j y\n2 -15 -15 1823000.0\n"
+ ]
+ }
+ ],
+ "source": [
+ "dp = get_dataframes_pipe(dataframe_size = 3)\n",
+ "dp['y'] = dp.i * 100 + dp.j - 2.7\n",
+ "dp = dp.shuffle()\n",
+ "dp = dp - 17\n",
+ "dp['y'] = dp.y * 10000\n",
+ "for i in dp.raw_iterator():\n",
+ " print(i)"
+ ]
+ },
+ {
+ "source": [
+    "Batching combines everything into a `list`; it is possible to nest `list`s. A list may contain any number of DataFrames as long as the total number of rows equals the batch size."
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Iterate over DataFrame batches\n[(6, 0),(0, 0)]\n[(4, 1),(1, 1)]\n[(2, 2),(9, 0)]\n[(3, 0),(5, 2)]\n[(7, 1),(8, 2)]\nIterate over regular batches\n[(1, 1),(4, 1)]\n[(2, 2),(3, 0)]\n[(6, 0),(7, 1)]\n[(8, 2),(0, 0)]\n[(5, 2),(9, 0)]\n"
+ ]
+ }
+ ],
+ "source": [
+ "dp = get_dataframes_pipe(dataframe_size = 3)\n",
+ "dp = dp.shuffle()\n",
+ "dp = dp.batch(2)\n",
+ "print(\"Iterate over DataFrame batches\")\n",
+ "for i,v in enumerate(dp):\n",
+ " print(v)\n",
+ "\n",
+ "# this is similar to batching of regular DataPipe\n",
+ "dp = get_regular_pipe()\n",
+ "dp = dp.shuffle()\n",
+ "dp = dp.batch(2)\n",
+ "print(\"Iterate over regular batches\")\n",
+ "for i in dp:\n",
+ " print(i)"
+ ]
+ },
+ {
+ "source": [
+ "Some details about internal storage of batched DataFrames and how they are iterated"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Type: <class 'torch.utils.data.datapipes.iter.dataframes.DataChunkDF'>\n",
+ "As string: [(0, 0),(3, 0)]\n",
+ "Iterated regularly:\n",
+ "-- batch start --\n",
+ "(0, 0)\n",
+ "(3, 0)\n",
+ "-- batch end --\n",
+ "Iterated in inner format (for developers):\n",
+ "-- df batch start --\n",
+ " i j\n",
+ "0 0 0\n",
+ "0 3 0\n",
+ "-- df batch end --\n",
+ "Type: <class 'torch.utils.data.datapipes.iter.dataframes.DataChunkDF'>\n",
+ "As string: [(6, 0),(1, 1)]\n",
+ "Iterated regularly:\n",
+ "-- batch start --\n",
+ "(6, 0)\n",
+ "(1, 1)\n",
+ "-- batch end --\n",
+ "Iterated in inner format (for developers):\n",
+ "-- df batch start --\n",
+ " i j\n",
+ "0 6 0\n",
+ "1 1 1\n",
+ "-- df batch end --\n",
+ "Type: <class 'torch.utils.data.datapipes.iter.dataframes.DataChunkDF'>\n",
+ "As string: [(9, 0),(4, 1)]\n",
+ "Iterated regularly:\n",
+ "-- batch start --\n",
+ "(9, 0)\n",
+ "(4, 1)\n",
+ "-- batch end --\n",
+ "Iterated in inner format (for developers):\n",
+ "-- df batch start --\n",
+ " i j\n",
+ "0 9 0\n",
+ "1 4 1\n",
+ "-- df batch end --\n",
+ "Type: <class 'torch.utils.data.datapipes.iter.dataframes.DataChunkDF'>\n",
+ "As string: [(5, 2),(2, 2)]\n",
+ "Iterated regularly:\n",
+ "-- batch start --\n",
+ "(5, 2)\n",
+ "(2, 2)\n",
+ "-- batch end --\n",
+ "Iterated in inner format (for developers):\n",
+ "-- df batch start --\n",
+ " i j\n",
+ "2 5 2\n",
+ "2 2 2\n",
+ "-- df batch end --\n",
+ "Type: <class 'torch.utils.data.datapipes.iter.dataframes.DataChunkDF'>\n",
+ "As string: [(8, 2),(7, 1)]\n",
+ "Iterated regularly:\n",
+ "-- batch start --\n",
+ "(8, 2)\n",
+ "(7, 1)\n",
+ "-- batch end --\n",
+ "Iterated in inner format (for developers):\n",
+ "-- df batch start --\n",
+ " i j\n",
+ "2 8 2\n",
+ "1 7 1\n",
+ "-- df batch end --\n"
+ ]
+ }
+ ],
+ "source": [
+ "dp = get_dataframes_pipe(dataframe_size = 3)\n",
+ "dp = dp.shuffle()\n",
+ "dp = dp.batch(2)\n",
+ "for i in dp:\n",
+ " print(\"Type: \", type(i))\n",
+ " print(\"As string: \", i)\n",
+ " print(\"Iterated regularly:\")\n",
+ " print('-- batch start --')\n",
+ " for item in i:\n",
+ " print(item)\n",
+ " print('-- batch end --')\n",
+ " print(\"Iterated in inner format (for developers):\")\n",
+ " print('-- df batch start --')\n",
+ " for item in i.raw_iterator():\n",
+ " print(item)\n",
+ " print('-- df batch end --') "
+ ]
+ },
+ {
+ "source": [
+    "`concat` should only work on DFs with the same schema; this code should produce an error"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# TODO!\n",
+ "# dp0 = get_dataframes_pipe(range = 8, dataframe_size = 4)\n",
+ "# dp = get_dataframes_pipe(range = 6, dataframe_size = 3)\n",
+ "# dp['y'] = dp.i * 100 + dp.j - 2.7\n",
+ "# dp = dp.concat(dp0)\n",
+ "# for i,v in enumerate(dp.raw_iterator()):\n",
+ "# print(v)"
+ ]
+ },
+ {
+ "source": [
+    "`unbatch` of a `list` containing DataFrames works similarly to regular unbatch.\n",
+ "Note: DataFrame sizes might change"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "error",
+ "ename": "AttributeError",
+ "evalue": "",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m<ipython-input-12-fa80e9c68655>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;31m# Here is bug with unbatching which doesn't detect DF type.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mdp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'z'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0my\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;36m100\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mraw_iterator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/dataset/pytorch/torch/utils/data/dataset.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, attribute_name)\u001b[0m\n\u001b[1;32m 222\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mfunction\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 224\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mAttributeError\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 225\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 226\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__reduce_ex__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mAttributeError\u001b[0m: "
+ ]
+ }
+ ],
+ "source": [
+ "dp = get_dataframes_pipe(range = 18, dataframe_size = 3)\n",
+ "dp['y'] = dp.i * 100 + dp.j - 2.7\n",
+ "dp = dp.batch(5).batch(3).batch(1).unbatch(unbatch_level = 3)\n",
+ "\n",
+ "# Here is bug with unbatching which doesn't detect DF type.\n",
+ "dp['z'] = dp.y - 100\n",
+ "\n",
+ "for i in dp.raw_iterator():\n",
+ " print(i)"
+ ]
+ },
+ {
+ "source": [
+    "`map` is applied to individual rows; the `nesting_level` argument is used to penetrate batching"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Iterate over DataFrame batches\n[(1111000, 1111000),(1112000, 1112000),(1113000, 1113000),(1114000, 1111000),(1115000, 1112000)]\n[(1116000, 1113000),(1117000, 1111000),(1118000, 1112000),(1119000, 1113000),(1120000, 1111000)]\nIterate over regular batches\n[(1111000, 0),(1112000, 1),(1113000, 2),(1114000, 0),(1115000, 1)]\n[(1116000, 2),(1117000, 0),(1118000, 1),(1119000, 2),(1120000, 0)]\n"
+ ]
+ }
+ ],
+ "source": [
+ "dp = get_dataframes_pipe(range = 10, dataframe_size = 3)\n",
+ "dp = dp.map(lambda x: x + 1111)\n",
+ "dp = dp.batch(5).map(lambda x: x * 1000, nesting_level = 1)\n",
+ "print(\"Iterate over DataFrame batches\")\n",
+ "for i in dp:\n",
+ " print(i)\n",
+ "\n",
+ "# Similarly works on row level for classic DataPipe elements\n",
+ "dp = get_regular_pipe(range = 10)\n",
+ "dp = dp.map(lambda x: (x[0] + 1111, x[1]))\n",
+ "dp = dp.batch(5).map(lambda x: (x[0] * 1000, x[1]), nesting_level = 1)\n",
+ "print(\"Iterate over regular batches\")\n",
+ "for i in dp:\n",
+ " print(i)\n",
+ "\n"
+ ]
+ },
+ {
+ "source": [
+    "`filter` is applied to individual rows; the `nesting_level` argument is used to penetrate batching"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Iterate over DataFrame batches\n[(6, 0),(7, 1),(8, 2),(9, 0),(10, 1)]\n[(11, 2),(12, 0)]\nIterate over regular batches\n[(6, 0),(7, 1),(8, 2),(9, 0),(10, 1)]\n[(11, 2),(12, 0)]\n"
+ ]
+ }
+ ],
+ "source": [
+ "dp = get_dataframes_pipe(range = 30, dataframe_size = 3)\n",
+ "dp = dp.filter(lambda x: x.i > 5)\n",
+ "dp = dp.batch(5).filter(lambda x: x.i < 13, nesting_level = 1)\n",
+ "print(\"Iterate over DataFrame batches\")\n",
+ "for i in dp:\n",
+ " print(i)\n",
+ "\n",
+ "# Similarly works on row level for classic DataPipe elements\n",
+ "dp = get_regular_pipe(range = 30)\n",
+ "dp = dp.filter(lambda x: x[0] > 5)\n",
+ "dp = dp.batch(5).filter(lambda x: x[0] < 13, nesting_level = 1)\n",
+ "print(\"Iterate over regular batches\")\n",
+ "for i in dp:\n",
+ " print(i)"
+ ]
+ }
+ ]
+}
\ No newline at end of file
diff --git a/torch/utils/data/datapipes/__init__.py b/torch/utils/data/datapipes/__init__.py
index 0ada5a4..f19389e 100644
--- a/torch/utils/data/datapipes/__init__.py
+++ b/torch/utils/data/datapipes/__init__.py
@@ -1,2 +1,3 @@
from . import iter
from . import map
+from . import dataframe
diff --git a/torch/utils/data/datapipes/dataframe/__init__.py b/torch/utils/data/datapipes/dataframe/__init__.py
new file mode 100644
index 0000000..494bf32
--- /dev/null
+++ b/torch/utils/data/datapipes/dataframe/__init__.py
@@ -0,0 +1,11 @@
+from torch.utils.data.datapipes.dataframe.dataframes import (
+ DFIterDataPipe,
+)
+from torch.utils.data.datapipes.dataframe.datapipes import (
+ DataFramesAsTuplesPipe,
+)
+
+__all__ = ['DFIterDataPipe', 'DataFramesAsTuplesPipe']
+
+# Please keep this list sorted
+assert __all__ == sorted(__all__)
diff --git a/torch/utils/data/datapipes/dataframe/dataframes.py b/torch/utils/data/datapipes/dataframe/dataframes.py
new file mode 100644
index 0000000..2eab45f
--- /dev/null
+++ b/torch/utils/data/datapipes/dataframe/dataframes.py
@@ -0,0 +1,292 @@
+from typing import Any, Dict, List
+
+from torch.utils.data import (
+ DFIterDataPipe,
+ IterDataPipe,
+ functional_datapipe,
+)
+
+from torch.utils.data.datapipes.dataframe.structures import DataChunkDF
+
+# TODO(VitalyFedyunin): Add error when two different traces get combined
+class DataFrameTracedOps(DFIterDataPipe):
+ def __init__(self, source_datapipe, output_var):
+ self.source_datapipe = source_datapipe
+ self.output_var = output_var
+
+ def __iter__(self):
+ for item in self.source_datapipe:
+ yield self.output_var.calculate_me(item)
+
+
+# TODO(VitalyFedyunin): Extract this list from the DFIterDataPipe registered functions
+DATAPIPES_OPS = ['_dataframes_as_tuples', 'groupby', '_dataframes_filter', 'map', 'to_datapipe',
+ 'shuffle', 'concat', 'batch', '_dataframes_per_row', '_dataframes_concat', '_dataframes_shuffle']
+
+
+class Capture(object):
+ # All operations are shared across the entire InitialCapture; need to figure out what happens if we join two captures
+ ctx: Dict[str, List[Any]]
+
+ def __init__(self):
+ self.ctx = {'operations': [], 'variables': []}
+
+ def __str__(self):
+ return self.ops_str()
+
+ def ops_str(self):
+ res = ""
+ for op in self.ctx['operations']:
+ if len(res) > 0:
+ res += "\n"
+ res += str(op)
+ return res
+
+ def __getattr__(self, attrname):
+ if attrname == 'kwarg':
+ raise Exception('no kwargs!')
+ if attrname in DATAPIPES_OPS:
+ return (self.as_datapipe()).__getattr__(attrname)
+ return CaptureGetAttr(self, attrname, ctx=self.ctx)
+
+ def __getitem__(self, key):
+ return CaptureGetItem(self, key, ctx=self.ctx)
+
+ def __setitem__(self, key, value):
+ self.ctx['operations'].append(
+ CaptureSetItem(self, key, value, ctx=self.ctx))
+
+ def __add__(self, add_val):
+ res = CaptureAdd(self, add_val, ctx=self.ctx)
+ var = CaptureVariable(res, ctx=self.ctx)
+ self.ctx['operations'].append(
+ CaptureVariableAssign(variable=var, value=res, ctx=self.ctx))
+ return var
+
+ def __sub__(self, add_val):
+ res = CaptureSub(self, add_val, ctx=self.ctx)
+ var = CaptureVariable(res, ctx=self.ctx)
+ self.ctx['operations'].append(
+ CaptureVariableAssign(variable=var, value=res, ctx=self.ctx))
+ return var
+
+ def __mul__(self, add_val):
+ res = CaptureMul(self, add_val, ctx=self.ctx)
+ var = CaptureVariable(res, ctx=self.ctx)
+ t = CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
+ self.ctx['operations'].append(t)
+ return var
+
+ def as_datapipe(self):
+ return DataFrameTracedOps(
+ self.ctx['variables'][0].source_datapipe, self)
+
+ def raw_iterator(self):
+ return self.as_datapipe().__iter__()
+
+ def __iter__(self):
+ return iter(self._dataframes_as_tuples())
+
+ def batch(self, batch_size=10):
+ dp = self._dataframes_per_row()._dataframes_concat(batch_size)
+ dp = dp.as_datapipe().batch(1, wrapper_class=DataChunkDF)
+ dp._dp_contains_dataframe = True
+ return dp
+
+ def groupby(self,
+ group_key_fn,
+ *,
+ buffer_size=10000,
+ group_size=None,
+ unbatch_level=0,
+ guaranteed_group_size=None,
+ drop_remaining=False):
+ if unbatch_level != 0:
+ dp = self.unbatch(unbatch_level)._dataframes_per_row()
+ else:
+ dp = self._dataframes_per_row()
+ dp = dp.as_datapipe().groupby(group_key_fn, buffer_size=buffer_size, group_size=group_size,
+ guaranteed_group_size=guaranteed_group_size, drop_remaining=drop_remaining)
+ return dp
+
+ def shuffle(self, *args, **kwargs):
+ return self._dataframes_shuffle(*args, **kwargs)
+
+ def filter(self, *args, **kwargs):
+ return self._dataframes_filter(*args, **kwargs)
+
+
+class CaptureF(Capture):
+ def __init__(self, ctx=None, **kwargs):
+ if ctx is None:
+ self.ctx = {'operations': [], 'variables': []}
+ else:
+ self.ctx = ctx
+ self.kwargs = kwargs
+
+
+class CaptureCall(CaptureF):
+ def __str__(self):
+ return "{variable}({args},{kwargs})".format(**self.kwargs)
+
+ def execute(self):
+ return (get_val(self.kwargs['variable']))(*self.kwargs['args'], **self.kwargs['kwargs'])
+
+
+class CaptureVariableAssign(CaptureF):
+ def __str__(self):
+ return "{variable} = {value}".format(**self.kwargs)
+
+ def execute(self):
+ self.kwargs['variable'].calculated_value = self.kwargs['value'].execute()
+
+
+class CaptureVariable(Capture):
+ value = None
+ name = None
+ calculated_value = None
+ names_idx = 0
+
+ def __init__(self, value, ctx):
+ self.ctx = ctx
+ self.value = value
+ self.name = 'var_%s' % CaptureVariable.names_idx
+ CaptureVariable.names_idx += 1
+ self.ctx['variables'].append(self)
+
+ def __str__(self):
+ return self.name
+
+ def execute(self):
+ return self.calculated_value
+
+ def calculate_me(self, dataframe):
+ self.ctx['variables'][0].calculated_value = dataframe
+ for op in self.ctx['operations']:
+ op.execute()
+ return self.calculated_value
+
+
+class CaptureInitial(CaptureVariable):
+
+ def __init__(self):
+ new_ctx: Dict[str, List[Any]] = {'operations': [], 'variables': []}
+ super().__init__(None, new_ctx)
+ self.name = 'input_%s' % self.name
+
+
+class CaptureGetItem(Capture):
+ left : Capture
+ key : Any
+
+ def __init__(self, left, key, ctx):
+ self.ctx = ctx
+ self.left = left
+ self.key = key
+
+ def __str__(self):
+ return "%s[%s]" % (self.left, get_val(self.key))
+
+ def execute(self):
+ return (self.left.execute())[self.key]
+
+
+class CaptureSetItem(Capture):
+ left : Capture
+ key : Any
+ value : Capture
+
+ def __init__(self, left, key, value, ctx):
+ self.ctx = ctx
+ self.left = left
+ self.key = key
+ self.value = value
+
+ def __str__(self):
+ return "%s[%s] = %s" % (self.left, get_val(self.key), self.value)
+
+ def execute(self):
+ (self.left.execute())[
+ self.key] = self.value.execute()
+
+
+class CaptureAdd(Capture):
+ left = None
+ right = None
+
+ def __init__(self, left, right, ctx):
+ self.ctx = ctx
+ self.left = left
+ self.right = right
+
+ def __str__(self):
+ return "%s + %s" % (self.left, self.right)
+
+ def execute(self):
+ return get_val(self.left) + get_val(self.right)
+
+
+class CaptureMul(Capture):
+ left = None
+ right = None
+
+ def __init__(self, left, right, ctx):
+ self.ctx = ctx
+ self.left = left
+ self.right = right
+
+ def __str__(self):
+ return "%s * %s" % (self.left, self.right)
+
+ def execute(self):
+ return get_val(self.left) * get_val(self.right)
+
+
+class CaptureSub(Capture):
+ left = None
+ right = None
+
+ def __init__(self, left, right, ctx):
+ self.ctx = ctx
+ self.left = left
+ self.right = right
+
+ def __str__(self):
+ return "%s - %s" % (self.left, self.right)
+
+ def execute(self):
+ return get_val(self.left) - get_val(self.right)
+
+
+class CaptureGetAttr(Capture):
+ source = None
+ name: str
+
+ def __init__(self, src, name, ctx):
+ self.ctx = ctx
+ self.src = src
+ self.name = name
+
+ def __str__(self):
+ return "%s.%s" % (self.src, self.name)
+
+ def execute(self):
+ val = get_val(self.src)
+ return getattr(val, self.name)
+
+
+def get_val(capture):
+ if isinstance(capture, Capture):
+ return capture.execute()
+ elif isinstance(capture, str):
+ return '"%s"' % capture
+ else:
+ return capture
+
+
+@functional_datapipe('trace_as_dataframe')
+class DataFrameTracer(CaptureInitial, IterDataPipe):
+ source_datapipe = None
+
+ def __init__(self, source_datapipe):
+ super().__init__()
+ self.source_datapipe = source_datapipe
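
To make the tracing machinery above concrete, here is a minimal sketch of how Capture records operations and replays them (ExampleIterPipe is the toy pipe from the notebook; the var_*/input_var_* numbering is session-dependent):

    dp = ExampleIterPipe(range=10).map(lambda i: (i, i % 3)) \
        ._to_dataframes_pipe(columns=['i', 'j'], dataframe_size=3)
    dp['y'] = dp.i * 100 + dp.j - 2.7   # nothing is computed yet, only recorded
    print(dp.ops_str())
    # var_1 = input_var_0.i * 100
    # var_2 = var_1 + input_var_0.j
    # var_3 = var_2 - 2.7
    # input_var_0["y"] = var_3
    for df in dp.raw_iterator():        # recorded operations are replayed on each raw DataFrame
        print(df)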
diff --git a/torch/utils/data/datapipes/dataframe/datapipes.py b/torch/utils/data/datapipes/dataframe/datapipes.py
new file mode 100644
index 0000000..f76189a
--- /dev/null
+++ b/torch/utils/data/datapipes/dataframe/datapipes.py
@@ -0,0 +1,137 @@
+import random
+
+from torch.utils.data import (
+ DFIterDataPipe,
+ IterDataPipe,
+ functional_datapipe,
+)
+
+try:
+ import pandas # type: ignore[import]
+ # pandas used only for prototyping, will be shortly replaced with TorchArrow
+ WITH_PANDAS = True
+except ImportError:
+ WITH_PANDAS = False
+
+
+@functional_datapipe('_dataframes_as_tuples')
+class DataFramesAsTuplesPipe(IterDataPipe):
+ def __init__(self, source_datapipe):
+ self.source_datapipe = source_datapipe
+
+ def __iter__(self):
+ for df in self.source_datapipe:
+ for record in df.to_records(index=False):
+ yield record
+
+
+@functional_datapipe('_dataframes_per_row', enable_df_api_tracing=True)
+class PerRowDataFramesPipe(DFIterDataPipe):
+ def __init__(self, source_datapipe):
+ self.source_datapipe = source_datapipe
+
+ def __iter__(self):
+ for df in self.source_datapipe:
+ for i in range(len(df.index)):
+ yield df[i:i + 1]
+
+
+@functional_datapipe('_dataframes_concat', enable_df_api_tracing=True)
+class ConcatDataFramesPipe(DFIterDataPipe):
+ def __init__(self, source_datapipe, batch=3):
+ self.source_datapipe = source_datapipe
+ self.batch = batch
+ if not WITH_PANDAS:
+ raise Exception('DataFrames prototype requires pandas to function')
+
+ def __iter__(self):
+ buffer = []
+ for df in self.source_datapipe:
+ buffer.append(df)
+ if len(buffer) == self.batch:
+ yield pandas.concat(buffer)
+ buffer = []
+ if len(buffer):
+ yield pandas.concat(buffer)
+
+
+@functional_datapipe('_dataframes_shuffle', enable_df_api_tracing=True)
+class ShuffleDataFramesPipe(DFIterDataPipe):
+ def __init__(self, source_datapipe):
+ self.source_datapipe = source_datapipe
+ if not WITH_PANDAS:
+ raise Exception('DataFrames prototype requires pandas to function')
+
+ def __iter__(self):
+ size = None
+ all_buffer = []
+ for df in self.source_datapipe:
+ if size is None:
+ size = len(df.index)
+ for i in range(len(df.index)):
+ all_buffer.append(df[i:i + 1])
+ random.shuffle(all_buffer)
+ buffer = []
+ for df in all_buffer:
+ buffer.append(df)
+ if len(buffer) == size:
+ yield pandas.concat(buffer)
+ buffer = []
+ if len(buffer):
+ yield pandas.concat(buffer)
+
+
+@functional_datapipe('_dataframes_filter', enable_df_api_tracing=True)
+class FilterDataFramesPipe(DFIterDataPipe):
+ def __init__(self, source_datapipe, filter_fn):
+ self.source_datapipe = source_datapipe
+ self.filter_fn = filter_fn
+ if not WITH_PANDAS:
+ raise Exception('DataFrames prototype requires pandas to function')
+
+ def __iter__(self):
+ size = None
+ all_buffer = []
+ filter_res = []
+ for df in self.source_datapipe:
+ if size is None:
+ size = len(df.index)
+ for i in range(len(df.index)):
+ all_buffer.append(df[i:i + 1])
+ filter_res.append(self.filter_fn(df.iloc[i]))
+
+ buffer = []
+ for df, res in zip(all_buffer, filter_res):
+ if res:
+ buffer.append(df)
+ if len(buffer) == size:
+ yield pandas.concat(buffer)
+ buffer = []
+ if len(buffer):
+ yield pandas.concat(buffer)
+
+
+@functional_datapipe('_to_dataframes_pipe', enable_df_api_tracing=True)
+class ExampleAggregateAsDataFrames(DFIterDataPipe):
+ def __init__(self, source_datapipe, dataframe_size=10, columns=None):
+ self.source_datapipe = source_datapipe
+ self.columns = columns
+ self.dataframe_size = dataframe_size
+ if not WITH_PANDAS:
+ raise Exception('DataFrames prototype requires pandas to function')
+
+ def _as_list(self, item):
+ try:
+ return list(item)
+ except Exception: # TODO(VitalyFedyunin): Replace with better iterable exception
+ return [item]
+
+ def __iter__(self):
+ aggregate = []
+ for item in self.source_datapipe:
+ aggregate.append(self._as_list(item))
+ if len(aggregate) == self.dataframe_size:
+ yield pandas.DataFrame(aggregate, columns=self.columns)
+ aggregate = []
+ if len(aggregate) > 0:
+ yield pandas.DataFrame(aggregate, columns=self.columns)
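
A minimal sketch of ExampleAggregateAsDataFrames in action (ExampleIterPipe as in the notebook): incoming rows are grouped into DataFrames of at most dataframe_size rows, with the final DataFrame possibly smaller:

    dp = ExampleIterPipe(range=10).map(lambda i: (i, i % 3)) \
        ._to_dataframes_pipe(columns=['i', 'j'], dataframe_size=7)
    for df in dp.raw_iterator():
        print(len(df.index))   # 7, then 3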
diff --git a/torch/utils/data/datapipes/dataframe/structures.py b/torch/utils/data/datapipes/dataframe/structures.py
new file mode 100644
index 0000000..c822f89
--- /dev/null
+++ b/torch/utils/data/datapipes/dataframe/structures.py
@@ -0,0 +1,20 @@
+from torch.utils.data import (
+ DataChunk,
+)
+
+class DataChunkDF(DataChunk):
+ """
+ DataChunkDF iterates over individual rows inside its DataFrame containers;
+ to access the DataFrames themselves, use `raw_iterator`.
+ """
+
+ def __iter__(self):
+ for df in self.items:
+ for record in df.to_records(index=False):
+ yield record
+
+ def __len__(self):
+ total_len = 0
+ for df in self.items:
+ total_len += len(df)
+ return total_len
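
A small standalone example of the container above, assuming pandas is available: len counts rows across all contained DataFrames, and iteration yields individual records rather than the DataFrames themselves:

    import pandas
    from torch.utils.data.datapipes.dataframe.structures import DataChunkDF

    chunk = DataChunkDF([
        pandas.DataFrame([(0, 0), (1, 1)], columns=['i', 'j']),
        pandas.DataFrame([(2, 2)], columns=['i', 'j']),
    ])
    print(len(chunk))      # 3 -- total number of rows, not number of DataFrames
    for record in chunk:   # rows are yielded one by one
        print(tuple(record))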
diff --git a/torch/utils/data/datapipes/iter/__init__.py b/torch/utils/data/datapipes/iter/__init__.py
index 26d715d..0b96cfd 100644
--- a/torch/utils/data/datapipes/iter/__init__.py
+++ b/torch/utils/data/datapipes/iter/__init__.py
@@ -54,6 +54,7 @@
'BucketBatcher',
'Collator',
'Concater',
+ 'DFIterDataPipe',
'Demultiplexer',
'FileLister',
'FileLoader',
diff --git a/torch/utils/data/datapipes/iter/grouping.py b/torch/utils/data/datapipes/iter/grouping.py
index d90ad08..f848681 100644
--- a/torch/utils/data/datapipes/iter/grouping.py
+++ b/torch/utils/data/datapipes/iter/grouping.py
@@ -60,6 +60,7 @@
batch_size: int,
drop_last: bool = False,
unbatch_level: int = 0,
+ wrapper_class=DataChunk,
) -> None:
assert batch_size > 0, "Batch size is required to be larger than 0!"
super().__init__()
@@ -71,7 +72,7 @@
self.batch_size = batch_size
self.drop_last = drop_last
self.length = None
- self.wrapper_class = DataChunk
+ self.wrapper_class = wrapper_class
def __iter__(self) -> Iterator[DataChunk]:
batch: List = []
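
A minimal sketch of how the new wrapper_class hook is consumed (this mirrors Capture.batch in dataframes.py; df_rows stands for a hypothetical pipe that yields single-row DataFrames):

    from torch.utils.data.datapipes.dataframe.structures import DataChunkDF

    # Each emitted batch is wrapped in DataChunkDF instead of the default DataChunk,
    # so iterating a batch yields rows while raw_iterator() still exposes DataFrames.
    batched = df_rows.batch(1, wrapper_class=DataChunkDF)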
diff --git a/torch/utils/data/datapipes/iter/selecting.py b/torch/utils/data/datapipes/iter/selecting.py
index f1889e5..2f0a158 100644
--- a/torch/utils/data/datapipes/iter/selecting.py
+++ b/torch/utils/data/datapipes/iter/selecting.py
@@ -1,6 +1,14 @@
import warnings
-from torch.utils.data import IterDataPipe, functional_datapipe, DataChunk
-from typing import Callable, TypeVar, Iterator, Optional, Tuple, Dict
+from typing import Callable, Dict, Iterator, Optional, Tuple, TypeVar
+
+from torch.utils.data import DataChunk, IterDataPipe, functional_datapipe
+
+try:
+ import pandas # type: ignore[import]
+ # pandas used only for prototyping, will be shortly replaced with TorchArrow
+ WITH_PANDAS = True
+except ImportError:
+ WITH_PANDAS = False
T_co = TypeVar('T_co', covariant=True)
@@ -91,12 +99,27 @@
def _returnIfTrue(self, data):
condition = self.filter_fn(data, *self.args, **self.kwargs)
+ if WITH_PANDAS:
+ if isinstance(condition, pandas.core.series.Series):
+ # We are operating on a DataFrame filter here
+ result = []
+ for idx, mask in enumerate(condition):
+ if mask:
+ result.append(data[idx:idx + 1])
+ if len(result):
+ return pandas.concat(result)
+ else:
+ return None
+
if not isinstance(condition, bool):
- raise ValueError("Boolean output is required for `filter_fn` of FilterIterDataPipe")
+ raise ValueError("Boolean output is required for `filter_fn` of FilterIterDataPipe, got", type(condition))
if condition:
return data
def _isNonEmpty(self, data):
+ if WITH_PANDAS:
+ if isinstance(data, pandas.core.frame.DataFrame):
+ return True
r = data is not None and \
not (isinstance(data, list) and len(data) == 0 and self.drop_empty_batches)
return r
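
A standalone sketch of the new Series-mask branch in _returnIfTrue (this re-implements the same logic outside the class for illustration; keep_masked_rows is not part of the PR):

    import pandas

    def keep_masked_rows(df, mask):
        # keep the rows whose mask entry is True, mirroring the DataFrame branch above
        result = [df[idx:idx + 1] for idx, m in enumerate(mask) if m]
        return pandas.concat(result) if result else None

    df = pandas.DataFrame({'i': [4, 5, 6, 7]})
    print(keep_masked_rows(df, df.i > 5))   # keeps the rows with i == 6 and i == 7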
diff --git a/torch/utils/data/dataset.py b/torch/utils/data/dataset.py
index 50488d1..c968fdf 100644
--- a/torch/utils/data/dataset.py
+++ b/torch/utils/data/dataset.py
@@ -24,6 +24,11 @@
T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')
+UNTRACABLE_DATAFRAME_PIPES = ['batch', # As it returns DataChunks
+ 'groupby', # As it returns DataChunks
+ '_dataframes_as_tuples', # As it unpacks DF
+ 'trace_as_dataframe', # As it used to mark DF for tracing
+ ]
class DataChunk(list, Generic[T]):
def __init__(self, items):
@@ -82,13 +87,20 @@
cls.functions[function_name] = function
@classmethod
- def register_datapipe_as_function(cls, function_name, cls_to_register):
+ def register_datapipe_as_function(cls, function_name, cls_to_register, enable_df_api_tracing=False):
if function_name in cls.functions:
raise Exception("Unable to add DataPipe function name {} as it is already taken".format(function_name))
- def class_function(cls, source_dp, *args, **kwargs):
- return cls(source_dp, *args, **kwargs)
- function = functools.partial(class_function, cls_to_register)
+ def class_function(cls, enable_df_api_tracing, source_dp, *args, **kwargs):
+ result_pipe = cls(source_dp, *args, **kwargs)
+ if isinstance(result_pipe, Dataset):
+ if enable_df_api_tracing or isinstance(source_dp, DFIterDataPipe):
+ if function_name not in UNTRACABLE_DATAFRAME_PIPES:
+ result_pipe = result_pipe.trace_as_dataframe()
+
+ return result_pipe
+
+ function = functools.partial(class_function, cls_to_register, enable_df_api_tracing)
cls.functions[function_name] = function
@@ -227,6 +239,9 @@
raise Exception("Attempt to override existing reduce_ex_hook")
IterableDataset.reduce_ex_hook = hook_fn
+class DFIterDataPipe(IterableDataset):
+ def _is_dfpipe(self):
+ return True
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
r"""Dataset wrapping tensors.