Added PandasDataframeWrapper (#65411)

Summary:
- Added `PandasDataframeWrapper` around `pandas` functions so that `torcharrow` can easily be dropped in as a replacement for Facebook-internal use cases (see the sketch below)
- Updated the relevant datapipe/dataframe call sites to go through the new `PandasDataframeWrapper` instead of calling `pandas` functions directly
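
As a rough illustration of the intended extension point (not part of this diff), any class that exposes the same classmethod interface as `PandasWrapper` can be registered via `dataframe_wrapper.set_df_wrapper`. The `ListOfRowsWrapper` below is a hypothetical, list-backed toy backend used purely to show the registration mechanics:

```python
from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper


class ListOfRowsWrapper:
    """Toy backend that represents a 'dataframe' as a plain list of row tuples."""

    @classmethod
    def create_dataframe(cls, data, columns):
        return [tuple(row) for row in data]

    @classmethod
    def is_dataframe(cls, data):
        return isinstance(data, list)

    @classmethod
    def is_column(cls, data):
        # This toy backend has no separate column type.
        return False

    @classmethod
    def iterate(cls, data):
        yield from data

    @classmethod
    def concat(cls, buffer):
        return [row for chunk in buffer for row in chunk]

    @classmethod
    def get_item(cls, data, idx):
        return data[idx:idx + 1]

    @classmethod
    def get_len(cls, df):
        return len(df)


# Register the replacement; datapipe code that goes through the module-level
# df_wrapper helpers will now use this backend instead of pandas.
df_wrapper.set_df_wrapper(ListOfRowsWrapper)
```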

Pull Request resolved: https://github.com/pytorch/pytorch/pull/65411

Reviewed By: VitalyFedyunin, hudeven

Differential Revision: D31087746

Pulled By: Nayef211

fbshipit-source-id: 299901f93a967a5fb8ed99d6db9b8b9203634b8f
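
For reference, a minimal sketch (not part of this diff) of exercising the new module-level helpers directly; it assumes `pandas` is installed so the default `PandasWrapper` backend is in effect:

```python
from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper

# Build two small frames through the wrapper rather than calling pandas directly.
df1 = df_wrapper.create_dataframe([(1, "a"), (2, "b")], columns=["id", "label"])
df2 = df_wrapper.create_dataframe([(3, "c")], columns=["id", "label"])

# concat / is_dataframe / get_len mirror pandas.concat, an isinstance check,
# and len(df.index) respectively.
combined = df_wrapper.concat([df1, df2])
assert df_wrapper.is_dataframe(combined)
assert df_wrapper.get_len(combined) == 3

# get_item returns a single-row slice of the frame.
row = df_wrapper.get_item(combined, 1)
assert df_wrapper.get_len(row) == 1
```
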
diff --git a/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py b/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py
new file mode 100644
index 0000000..1c61dcb
--- /dev/null
+++ b/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py
@@ -0,0 +1,100 @@
+try:
+    import pandas  # type: ignore[import]
+
+    # pandas is used only for prototyping and will shortly be replaced by TorchArrow
+    WITH_PANDAS = True
+except ImportError:
+    WITH_PANDAS = False
+
+
+class PandasWrapper:
+    @classmethod
+    def create_dataframe(cls, data, columns):
+        if not WITH_PANDAS:
+            raise Exception("DataFrames prototype requires pandas to function")
+        return pandas.DataFrame(data, columns=columns)
+
+    @classmethod
+    def is_dataframe(cls, data):
+        if not WITH_PANDAS:
+            return False
+        return isinstance(data, pandas.core.frame.DataFrame)
+
+    @classmethod
+    def is_column(cls, data):
+        if not WITH_PANDAS:
+            return False
+        return isinstance(data, pandas.core.series.Series)
+
+    @classmethod
+    def iterate(cls, data):
+        if not WITH_PANDAS:
+            raise Exception("DataFrames prototype requires pandas to function")
+        for d in data:
+            yield d
+
+    @classmethod
+    def concat(cls, buffer):
+        if not WITH_PANDAS:
+            raise Exception("DataFrames prototype requires pandas to function")
+        return pandas.concat(buffer)
+
+    @classmethod
+    def get_item(cls, data, idx):
+        if not WITH_PANDAS:
+            raise Exception("DataFrames prototype requires pandas to function")
+        return data[idx : idx + 1]
+
+    @classmethod
+    def get_len(cls, df):
+        if not WITH_PANDAS:
+            raise Exception("DataFrames prototype requires pandas to function")
+        return len(df.index)
+
+
+# To plug in your own implementation, override the default with dataframe_wrapper.set_df_wrapper(new_wrapper_class)
+default_wrapper = PandasWrapper
+
+
+def get_df_wrapper():
+    return default_wrapper
+
+
+def set_df_wrapper(wrapper):
+    global default_wrapper
+    default_wrapper = wrapper
+
+
+def create_dataframe(data, columns=None):
+    wrapper = get_df_wrapper()
+    return wrapper.create_dataframe(data, columns)
+
+
+def is_dataframe(data):
+    wrapper = get_df_wrapper()
+    return wrapper.is_dataframe(data)
+
+
+def is_column(data):
+    wrapper = get_df_wrapper()
+    return wrapper.is_column(data)
+
+
+def concat(buffer):
+    wrapper = get_df_wrapper()
+    return wrapper.concat(buffer)
+
+
+def iterate(data):
+    wrapper = get_df_wrapper()
+    return wrapper.iterate(data)
+
+
+def get_item(data, idx):
+    wrapper = get_df_wrapper()
+    return wrapper.get_item(data, idx)
+
+
+def get_len(df):
+    wrapper = get_df_wrapper()
+    return wrapper.get_len(df)
diff --git a/torch/utils/data/datapipes/dataframe/datapipes.py b/torch/utils/data/datapipes/dataframe/datapipes.py
index f76189a..43e15c3 100644
--- a/torch/utils/data/datapipes/dataframe/datapipes.py
+++ b/torch/utils/data/datapipes/dataframe/datapipes.py
@@ -5,14 +5,7 @@
     IterDataPipe,
     functional_datapipe,
 )
-
-try:
-    import pandas  # type: ignore[import]
-    # pandas used only for prototyping, will be shortly replaced with TorchArrow
-    WITH_PANDAS = True
-except ImportError:
-    WITH_PANDAS = False
-
+from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
 
 @functional_datapipe('_dataframes_as_tuples')
 class DataFramesAsTuplesPipe(IterDataPipe):
@@ -41,44 +34,40 @@
     def __init__(self, source_datapipe, batch=3):
         self.source_datapipe = source_datapipe
         self.batch = batch
-        if not WITH_PANDAS:
-            Exception('DataFrames prototype requires pandas to function')
 
     def __iter__(self):
         buffer = []
         for df in self.source_datapipe:
             buffer.append(df)
             if len(buffer) == self.batch:
-                yield pandas.concat(buffer)
+                yield df_wrapper.concat(buffer)
                 buffer = []
         if len(buffer):
-            yield pandas.concat(buffer)
+            yield df_wrapper.concat(buffer)
 
 
 @functional_datapipe('_dataframes_shuffle', enable_df_api_tracing=True)
 class ShuffleDataFramesPipe(DFIterDataPipe):
     def __init__(self, source_datapipe):
         self.source_datapipe = source_datapipe
-        if not WITH_PANDAS:
-            Exception('DataFrames prototype requires pandas to function')
 
     def __iter__(self):
         size = None
         all_buffer = []
         for df in self.source_datapipe:
             if size is None:
-                size = len(df.index)
-            for i in range(len(df.index)):
-                all_buffer.append(df[i:i + 1])
+                size = df_wrapper.get_len(df)
+            for i in range(df_wrapper.get_len(df)):
+                all_buffer.append(df_wrapper.get_item(df, i))
         random.shuffle(all_buffer)
         buffer = []
         for df in all_buffer:
             buffer.append(df)
             if len(buffer) == size:
-                yield pandas.concat(buffer)
+                yield df_wrapper.concat(buffer)
                 buffer = []
         if len(buffer):
-            yield pandas.concat(buffer)
+            yield df_wrapper.concat(buffer)
 
 
 @functional_datapipe('_dataframes_filter', enable_df_api_tracing=True)
@@ -86,8 +75,6 @@
     def __init__(self, source_datapipe, filter_fn):
         self.source_datapipe = source_datapipe
         self.filter_fn = filter_fn
-        if not WITH_PANDAS:
-            Exception('DataFrames prototype requires pandas to function')
 
     def __iter__(self):
         size = None
@@ -105,10 +92,10 @@
             if res:
                 buffer.append(df)
                 if len(buffer) == size:
-                    yield pandas.concat(buffer)
+                    yield df_wrapper.concat(buffer)
                     buffer = []
         if len(buffer):
-            yield pandas.concat(buffer)
+            yield df_wrapper.concat(buffer)
 
 
 @functional_datapipe('_to_dataframes_pipe', enable_df_api_tracing=True)
@@ -117,8 +104,6 @@
         self.source_datapipe = source_datapipe
         self.columns = columns
         self.dataframe_size = dataframe_size
-        if not WITH_PANDAS:
-            Exception('DataFrames prototype requires pandas to function')
 
     def _as_list(self, item):
         try:
@@ -131,7 +116,7 @@
         for item in self.source_datapipe:
             aggregate.append(self._as_list(item))
             if len(aggregate) == self.dataframe_size:
-                yield pandas.DataFrame(aggregate, columns=self.columns)
+                yield df_wrapper.create_dataframe(aggregate, columns=self.columns)
                 aggregate = []
         if len(aggregate) > 0:
-            yield pandas.DataFrame(aggregate, columns=self.columns)
+            yield df_wrapper.create_dataframe(aggregate, columns=self.columns)
diff --git a/torch/utils/data/datapipes/iter/selecting.py b/torch/utils/data/datapipes/iter/selecting.py
index 2f0a158..aa7e699 100644
--- a/torch/utils/data/datapipes/iter/selecting.py
+++ b/torch/utils/data/datapipes/iter/selecting.py
@@ -2,13 +2,7 @@
 from typing import Callable, Dict, Iterator, Optional, Tuple, TypeVar
 
 from torch.utils.data import DataChunk, IterDataPipe, functional_datapipe
-
-try:
-    import pandas  # type: ignore[import]
-    # pandas used only for prototyping, will be shortly replaced with TorchArrow
-    WITH_PANDAS = True
-except ImportError:
-    WITH_PANDAS = False
+from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
 
 T_co = TypeVar('T_co', covariant=True)
 
@@ -99,17 +93,17 @@
 
     def _returnIfTrue(self, data):
         condition = self.filter_fn(data, *self.args, **self.kwargs)
-        if WITH_PANDAS:
-            if isinstance(condition, pandas.core.series.Series):
-                # We are operatring on DataFrames filter here
-                result = []
-                for idx, mask in enumerate(condition):
-                    if mask:
-                        result.append(data[idx:idx + 1])
-                if len(result):
-                    return pandas.concat(result)
-                else:
-                    return None
+
+        if df_wrapper.is_column(condition):
+            # We are operating on a DataFrame filter here
+            result = []
+            for idx, mask in enumerate(df_wrapper.iterate(condition)):
+                if mask:
+                    result.append(df_wrapper.get_item(data, idx))
+            if len(result):
+                return df_wrapper.concat(result)
+            else:
+                return None
 
         if not isinstance(condition, bool):
             raise ValueError("Boolean output is required for `filter_fn` of FilterIterDataPipe, got", type(condition))
@@ -117,9 +111,8 @@
             return data
 
     def _isNonEmpty(self, data):
-        if WITH_PANDAS:
-            if isinstance(data, pandas.core.frame.DataFrame):
-                return True
+        if df_wrapper.is_dataframe(data):
+            return True
         r = data is not None and \
             not (isinstance(data, list) and len(data) == 0 and self.drop_empty_batches)
         return r