[PyTorch] Move NestedTensor printing to _tensor_str.py (#74000)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74000
Now that we're in-core, we can just customize this.
ghstack-source-id: 151540966
Test Plan: Existing test_nestedtensor seems to pass
Reviewed By: ezyang
Differential Revision: D34665270
fbshipit-source-id: 5097944a4dc4fe80cea2b8576f0123466dbeab43
(cherry picked from commit d0315f46f9906c904639f43f218e439407f5b2a7)
diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py
index b0bb6e9..1c97505 100644
--- a/torch/_tensor_str.py
+++ b/torch/_tensor_str.py
@@ -298,14 +298,14 @@
return torch.stack([get_summarized_data(x) for x in self])
def _str_intern(inp):
- prefix = 'tensor('
+ self, tangent = torch.autograd.forward_ad.unpack_dual(inp)
+ prefix = "nested_tensor(" if self.is_nested else 'tensor('
indent = len(prefix)
suffixes = []
# This is used to extract the primal value and thus disable the forward AD
# within this function.
# TODO(albanD) This needs to be updated when more than one level is supported
- self, tangent = torch.autograd.forward_ad.unpack_dual(inp)
# Note [Print tensor device]:
# A general logic here is we only print device when it doesn't match
@@ -380,6 +380,11 @@
suffixes.append('zero_point=' + str(self.q_per_channel_zero_points()))
suffixes.append('axis=' + str(self.q_per_channel_axis()))
tensor_str = _tensor_str(self.dequantize(), indent)
+ elif self.is_nested:
+ def indented_str(s, indent):
+ return "\n".join(f" {line}" for line in s.split("\n"))
+ strs = ",\n".join(indented_str(str(t), indent + 1) for t in torch.ops.aten.unbind.int(self, 0))
+ tensor_str = f"[\n{strs}\n]"
else:
if self.is_meta:
suffixes.append('size=' + str(tuple(self.shape)))
diff --git a/torch/nested/_nestedtensor.py b/torch/nested/_nestedtensor.py
index 4f16a31..6be365a 100644
--- a/torch/nested/_nestedtensor.py
+++ b/torch/nested/_nestedtensor.py
@@ -77,22 +77,7 @@
return self._impl.is_contiguous()
def __str__(self):
- def _str(x, indent=0, tab=" "):
- s = indent * tab + "[\n"
- strs = list(map(str, x.unbind()))
- strs = list(
- map(
- lambda xi: "\n".join(
- map(lambda xij: (indent + 1) * tab + xij, xi.split("\n"))
- ),
- strs,
- )
- )
- s += ",\n".join(strs)
- s += "\n" + indent * tab + "]"
- return s
-
- return "nested_tensor(" + _str(self) + ")"
+ return str(self._impl)
def __repr__(self):
return self.__str__()