[PyTorch] Move NestedTensor printing to _tensor_str.py (#74000)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74000

Now that we're in-core, we can just customize this.
ghstack-source-id: 151540966

Test Plan: Existing test_nestedtensor seems to pass

Reviewed By: ezyang

Differential Revision: D34665270

fbshipit-source-id: 5097944a4dc4fe80cea2b8576f0123466dbeab43
(cherry picked from commit d0315f46f9906c904639f43f218e439407f5b2a7)
diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py
index b0bb6e9..1c97505 100644
--- a/torch/_tensor_str.py
+++ b/torch/_tensor_str.py
@@ -298,14 +298,14 @@
         return torch.stack([get_summarized_data(x) for x in self])
 
 def _str_intern(inp):
-    prefix = 'tensor('
+    self, tangent = torch.autograd.forward_ad.unpack_dual(inp)
+    prefix = "nested_tensor(" if self.is_nested else 'tensor('
     indent = len(prefix)
     suffixes = []
 
     # This is used to extract the primal value and thus disable the forward AD
     # within this function.
     # TODO(albanD) This needs to be updated when more than one level is supported
-    self, tangent = torch.autograd.forward_ad.unpack_dual(inp)
 
     # Note [Print tensor device]:
     # A general logic here is we only print device when it doesn't match
@@ -380,6 +380,11 @@
             suffixes.append('zero_point=' + str(self.q_per_channel_zero_points()))
             suffixes.append('axis=' + str(self.q_per_channel_axis()))
         tensor_str = _tensor_str(self.dequantize(), indent)
+    elif self.is_nested:
+        def indented_str(s, indent):
+            return "\n".join(f"  {line}" for line in s.split("\n"))
+        strs = ",\n".join(indented_str(str(t), indent + 1) for t in torch.ops.aten.unbind.int(self, 0))
+        tensor_str = f"[\n{strs}\n]"
     else:
         if self.is_meta:
             suffixes.append('size=' + str(tuple(self.shape)))
diff --git a/torch/nested/_nestedtensor.py b/torch/nested/_nestedtensor.py
index 4f16a31..6be365a 100644
--- a/torch/nested/_nestedtensor.py
+++ b/torch/nested/_nestedtensor.py
@@ -77,22 +77,7 @@
         return self._impl.is_contiguous()
 
     def __str__(self):
-        def _str(x, indent=0, tab="  "):
-            s = indent * tab + "[\n"
-            strs = list(map(str, x.unbind()))
-            strs = list(
-                map(
-                    lambda xi: "\n".join(
-                        map(lambda xij: (indent + 1) * tab + xij, xi.split("\n"))
-                    ),
-                    strs,
-                )
-            )
-            s += ",\n".join(strs)
-            s += "\n" + indent * tab + "]"
-            return s
-
-        return "nested_tensor(" + _str(self) + ")"
+        return str(self._impl)
 
     def __repr__(self):
         return self.__str__()