Fix validate_input_col for nn.Module or Callable (#96213)
Forward fix the problem introduced in https://github.com/pytorch/pytorch/pull/95067
Not all `Callable` objects have `__name__` implemented. Using `repr` as the backup solution to get function name or reference.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96213
Approved by: https://github.com/NivekT
diff --git a/test/test_datapipe.py b/test/test_datapipe.py
index bc03000..a77d6ad 100644
--- a/test/test_datapipe.py
+++ b/test/test_datapipe.py
@@ -30,6 +30,7 @@
import numpy as np
import torch
+import torch.nn as nn
import torch.utils.data.datapipes as dp
import torch.utils.data.graph
import torch.utils.data.graph_settings
@@ -663,6 +664,16 @@
lambda_fn3 = lambda x: x >= 5 # noqa: E731
+class Add1Module(nn.Module):
+ def forward(self, x):
+ return x + 1
+
+
+class Add1Callable:
+ def __call__(self, x):
+ return x + 1
+
+
class TestFunctionalIterDataPipe(TestCase):
def _serialization_test_helper(self, datapipe, use_dill):
@@ -1326,6 +1337,10 @@
_helper(lambda data: (str(data[0]), data[1], data[2]), str, 0)
_helper(lambda data: (data[0], data[1], int(data[2])), int, 2)
+ # Handle nn.Module and Callable (without __name__ implemented)
+ _helper(lambda data: (data[0] + 1, data[1], data[2]), Add1Module(), 0)
+ _helper(lambda data: (data[0] + 1, data[1], data[2]), Add1Callable(), 0)
+
@suppress_warnings # Suppress warning for lambda fn
def test_map_dict_with_col_iterdatapipe(self):
def fn_11(d):
diff --git a/torch/utils/data/datapipes/utils/common.py b/torch/utils/data/datapipes/utils/common.py
index 0599da7..e39d67e 100644
--- a/torch/utils/data/datapipes/utils/common.py
+++ b/torch/utils/data/datapipes/utils/common.py
@@ -78,9 +78,9 @@
continue
if isinstance(fn, functools.partial):
- fn_name = fn.func.__name__
+ fn_name = getattr(fn.func, "__name__", repr(fn.func))
else:
- fn_name = fn.__name__
+ fn_name = getattr(fn, "__name__", repr(fn))
if len(non_default_kw_only) > 0:
raise ValueError(