Added PandasDataframeWrapper (#65411)
Summary:
- Added `PandasDataframeWrapper` around `pandas` functions to easily drop-and-replace `torcharrow` for Facebook internal use cases
- Updated relevant datapipe/dataframe use sites to use the new `PandasDataframeWrapper` instead of calling `pandas` functions directly
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65411
Reviewed By: VitalyFedyunin, hudeven
Differential Revision: D31087746
Pulled By: Nayef211
fbshipit-source-id: 299901f93a967a5fb8ed99d6db9b8b9203634b8f
diff --git a/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py b/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py
new file mode 100644
index 0000000..1c61dcb
--- /dev/null
+++ b/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py
@@ -0,0 +1,102 @@
+try:
+    import pandas  # type: ignore[import]
+
+    # pandas used only for prototyping, will be shortly replaced with TorchArrow
+    WITH_PANDAS = True
+except ImportError:
+    WITH_PANDAS = False
+
+
+class PandasWrapper:
+    """Default dataframe backend: a thin classmethod facade over ``pandas``.
+
+    Operations raise if pandas is unavailable; the two type predicates
+    (``is_dataframe``, ``is_column``) simply report ``False`` instead.
+    """
+
+    @classmethod
+    def create_dataframe(cls, data, columns):
+        if not WITH_PANDAS:
+            raise Exception("DataFrames prototype requires pandas to function")
+        return pandas.DataFrame(data, columns=columns)
+
+    @classmethod
+    def is_dataframe(cls, data):
+        if not WITH_PANDAS:
+            return False
+        return isinstance(data, pandas.core.frame.DataFrame)
+
+    @classmethod
+    def is_column(cls, data):
+        if not WITH_PANDAS:
+            return False
+        return isinstance(data, pandas.core.series.Series)
+
+    @classmethod
+    def iterate(cls, data):
+        if not WITH_PANDAS:
+            raise Exception("DataFrames prototype requires pandas to function")
+        for d in data:
+            yield d
+
+    @classmethod
+    def concat(cls, buffer):
+        if not WITH_PANDAS:
+            raise Exception("DataFrames prototype requires pandas to function")
+        return pandas.concat(buffer)
+
+    @classmethod
+    def get_item(cls, data, idx):
+        if not WITH_PANDAS:
+            raise Exception("DataFrames prototype requires pandas to function")
+        return data[idx: idx + 1]
+
+    @classmethod
+    def get_len(cls, df):
+        if not WITH_PANDAS:
+            raise Exception("DataFrames prototype requires pandas to function")
+        return len(df.index)
+
+
+# When you build your own implementation, install it with
+# dataframe_wrapper.set_df_wrapper(new_wrapper_class)
+default_wrapper = PandasWrapper
+
+
+def get_df_wrapper():
+    return default_wrapper
+
+
+def set_df_wrapper(wrapper):
+    # `global` is required here: a bare assignment would only bind a
+    # local name and silently discard the caller's override.
+    global default_wrapper
+    default_wrapper = wrapper
+
+
+def create_dataframe(data, columns=None):
+    return get_df_wrapper().create_dataframe(data, columns)
+
+
+def is_dataframe(data):
+    return get_df_wrapper().is_dataframe(data)
+
+
+def is_column(data):
+    return get_df_wrapper().is_column(data)
+
+
+def concat(buffer):
+    return get_df_wrapper().concat(buffer)
+
+
+def iterate(data):
+    return get_df_wrapper().iterate(data)
+
+
+def get_item(data, idx):
+    return get_df_wrapper().get_item(data, idx)
+
+
+def get_len(df):
+    return get_df_wrapper().get_len(df)
diff --git a/torch/utils/data/datapipes/dataframe/datapipes.py b/torch/utils/data/datapipes/dataframe/datapipes.py
index f76189a..43e15c3 100644
--- a/torch/utils/data/datapipes/dataframe/datapipes.py
+++ b/torch/utils/data/datapipes/dataframe/datapipes.py
@@ -5,14 +5,7 @@
IterDataPipe,
functional_datapipe,
)
-
-try:
- import pandas # type: ignore[import]
- # pandas used only for prototyping, will be shortly replaced with TorchArrow
- WITH_PANDAS = True
-except ImportError:
- WITH_PANDAS = False
-
+from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
@functional_datapipe('_dataframes_as_tuples')
class DataFramesAsTuplesPipe(IterDataPipe):
@@ -41,44 +34,40 @@
def __init__(self, source_datapipe, batch=3):
self.source_datapipe = source_datapipe
self.batch = batch
- if not WITH_PANDAS:
- Exception('DataFrames prototype requires pandas to function')
def __iter__(self):
buffer = []
for df in self.source_datapipe:
buffer.append(df)
if len(buffer) == self.batch:
- yield pandas.concat(buffer)
+ yield df_wrapper.concat(buffer)
buffer = []
if len(buffer):
- yield pandas.concat(buffer)
+ yield df_wrapper.concat(buffer)
@functional_datapipe('_dataframes_shuffle', enable_df_api_tracing=True)
class ShuffleDataFramesPipe(DFIterDataPipe):
def __init__(self, source_datapipe):
self.source_datapipe = source_datapipe
- if not WITH_PANDAS:
- Exception('DataFrames prototype requires pandas to function')
def __iter__(self):
size = None
all_buffer = []
for df in self.source_datapipe:
if size is None:
- size = len(df.index)
- for i in range(len(df.index)):
- all_buffer.append(df[i:i + 1])
+ size = df_wrapper.get_len(df)
+ for i in range(df_wrapper.get_len(df)):
+ all_buffer.append(df_wrapper.get_item(df, i))
random.shuffle(all_buffer)
buffer = []
for df in all_buffer:
buffer.append(df)
if len(buffer) == size:
- yield pandas.concat(buffer)
+ yield df_wrapper.concat(buffer)
buffer = []
if len(buffer):
- yield pandas.concat(buffer)
+ yield df_wrapper.concat(buffer)
@functional_datapipe('_dataframes_filter', enable_df_api_tracing=True)
@@ -86,8 +75,6 @@
def __init__(self, source_datapipe, filter_fn):
self.source_datapipe = source_datapipe
self.filter_fn = filter_fn
- if not WITH_PANDAS:
- Exception('DataFrames prototype requires pandas to function')
def __iter__(self):
size = None
@@ -105,10 +92,10 @@
if res:
buffer.append(df)
if len(buffer) == size:
- yield pandas.concat(buffer)
+ yield df_wrapper.concat(buffer)
buffer = []
if len(buffer):
- yield pandas.concat(buffer)
+ yield df_wrapper.concat(buffer)
@functional_datapipe('_to_dataframes_pipe', enable_df_api_tracing=True)
@@ -117,8 +104,6 @@
self.source_datapipe = source_datapipe
self.columns = columns
self.dataframe_size = dataframe_size
- if not WITH_PANDAS:
- Exception('DataFrames prototype requires pandas to function')
def _as_list(self, item):
try:
@@ -131,7 +116,7 @@
for item in self.source_datapipe:
aggregate.append(self._as_list(item))
if len(aggregate) == self.dataframe_size:
- yield pandas.DataFrame(aggregate, columns=self.columns)
+ yield df_wrapper.create_dataframe(aggregate, columns=self.columns)
aggregate = []
if len(aggregate) > 0:
- yield pandas.DataFrame(aggregate, columns=self.columns)
+ yield df_wrapper.create_dataframe(aggregate, columns=self.columns)
diff --git a/torch/utils/data/datapipes/iter/selecting.py b/torch/utils/data/datapipes/iter/selecting.py
index 2f0a158..aa7e699 100644
--- a/torch/utils/data/datapipes/iter/selecting.py
+++ b/torch/utils/data/datapipes/iter/selecting.py
@@ -2,13 +2,7 @@
from typing import Callable, Dict, Iterator, Optional, Tuple, TypeVar
from torch.utils.data import DataChunk, IterDataPipe, functional_datapipe
-
-try:
- import pandas # type: ignore[import]
- # pandas used only for prototyping, will be shortly replaced with TorchArrow
- WITH_PANDAS = True
-except ImportError:
- WITH_PANDAS = False
+from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
T_co = TypeVar('T_co', covariant=True)
@@ -99,17 +93,17 @@
def _returnIfTrue(self, data):
condition = self.filter_fn(data, *self.args, **self.kwargs)
- if WITH_PANDAS:
- if isinstance(condition, pandas.core.series.Series):
- # We are operatring on DataFrames filter here
- result = []
- for idx, mask in enumerate(condition):
- if mask:
- result.append(data[idx:idx + 1])
- if len(result):
- return pandas.concat(result)
- else:
- return None
+
+ if df_wrapper.is_column(condition):
+            # We are operating on DataFrames filter here
+ result = []
+ for idx, mask in enumerate(df_wrapper.iterate(condition)):
+ if mask:
+ result.append(df_wrapper.get_item(data, idx))
+ if len(result):
+ return df_wrapper.concat(result)
+ else:
+ return None
if not isinstance(condition, bool):
raise ValueError("Boolean output is required for `filter_fn` of FilterIterDataPipe, got", type(condition))
@@ -117,9 +111,8 @@
return data
def _isNonEmpty(self, data):
- if WITH_PANDAS:
- if isinstance(data, pandas.core.frame.DataFrame):
- return True
+ if df_wrapper.is_dataframe(data):
+ return True
r = data is not None and \
not (isinstance(data, list) and len(data) == 0 and self.drop_empty_batches)
return r