Track base of FunctionalTensor in inference mode.  (#135141)

The idea behind the tracking is the following: whenever we see a tensor that is a root tensor (i.e., it does not have any view metas), we consider it the base of all the tensors that share its storage.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135141
Approved by: https://github.com/zou3519
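
Below is a minimal, illustrative Python sketch of that rule, not the patch's actual implementation: ToyStorage, ToyFunctionalTensor, and register are made-up names that only mirror the _storage_to_base WeakKeyDictionary added to FunctionalTensorMode and the _inference_mode_base attribute added to FunctionalTensor in this PR.

    # Illustrative-only sketch of the base-tracking rule described above.
    # All names here are hypothetical stand-ins, not PyTorch APIs.
    import weakref
    from typing import Optional


    class ToyStorage:
        """Stands in for an untyped storage shared by a base tensor and its views."""


    class ToyFunctionalTensor:
        def __init__(self, storage: ToyStorage, is_base: bool):
            self.storage = storage
            self.is_base = is_base  # "has no view metas" in the real wrapper
            self.inference_mode_base: Optional["ToyFunctionalTensor"] = None


    # Keyed weakly by storage so entries disappear once the storage does.
    storage_to_base: "weakref.WeakKeyDictionary[ToyStorage, ToyFunctionalTensor]" = (
        weakref.WeakKeyDictionary()
    )


    def register(t: ToyFunctionalTensor) -> None:
        if t.is_base:
            # Root tensor (no view metas): it becomes the base for every tensor
            # that shares its storage, and it has no base of its own.
            t.inference_mode_base = None
            storage_to_base[t.storage] = t
        else:
            # A view: its base is whichever root tensor owns the same storage.
            t.inference_mode_base = storage_to_base[t.storage]


    # Usage: wrap the base first, then a view aliasing the same storage.
    shared = ToyStorage()
    base = ToyFunctionalTensor(shared, is_base=True)
    view = ToyFunctionalTensor(shared, is_base=False)
    register(base)
    register(view)
    assert view.inference_mode_base is base

In the real patch, FunctionalTensor.__new__ performs this registration only when inference mode is enabled and enable_auto_functionalized_v2 is set, and get_base in auto_functionalize.py consults _inference_mode_base instead of _base under inference mode.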
diff --git a/aten/src/ATen/FunctionalTensorWrapper.cpp b/aten/src/ATen/FunctionalTensorWrapper.cpp
index dfd4928..6f66e80 100644
--- a/aten/src/ATen/FunctionalTensorWrapper.cpp
+++ b/aten/src/ATen/FunctionalTensorWrapper.cpp
@@ -707,7 +707,12 @@
 }
 
 bool isFunctionalTensor(const at::Tensor& tensor) {
-  return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Functionalize);
+   return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Functionalize);
+}
+
+bool isBaseTensor(const at::Tensor& tensor) {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(tensor));
+  return unsafeGetFunctionalWrapper(tensor)->isBaseTensor();
 }
 
 bool isFunctionalTensor(const std::optional<Tensor>& t) {
diff --git a/aten/src/ATen/FunctionalTensorWrapper.h b/aten/src/ATen/FunctionalTensorWrapper.h
index afb3af5..ed5daf9 100644
--- a/aten/src/ATen/FunctionalTensorWrapper.h
+++ b/aten/src/ATen/FunctionalTensorWrapper.h
@@ -165,6 +165,12 @@
     was_storage_changed_ = true;
   }
 
+  // A FunctionalTensor is considered a base if it is not a view of another
+  // tensor.
+  bool isBaseTensor() const {
+    return view_metas_.empty();
+  }
+
   c10::SymInt get_storage_size(bool before) {
     return functional_storage_impl()->get_storage_size(before);
   }
@@ -290,6 +296,8 @@
   return functional_impl;
 }
 
+TORCH_API bool isBaseTensor(const at::Tensor& tensor);
+
 TORCH_API bool isFunctionalTensor(const at::Tensor& tensor);
 TORCH_API bool isFunctionalTensor(const std::optional<Tensor>& t);
 TORCH_API bool isFunctionalTensor(
diff --git a/test/inductor/test_auto_functionalize.py b/test/inductor/test_auto_functionalize.py
index b1e3ab1..019e88c 100644
--- a/test/inductor/test_auto_functionalize.py
+++ b/test/inductor/test_auto_functionalize.py
@@ -911,85 +911,20 @@
         with torch.inference_mode():
             self.test_auto_functionalize_extra1()
 
-    # In inference mode we do not support inplacing views yet.
     @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
     def test_inference_mode2_v2(self):
-        with torch.inference_mode(), torch.library._scoped_library(
-            "mylib", "FRAGMENT"
-        ) as lib:
-            torch.library.define(
-                "mylib::foo",
-                "(Tensor(a!) x, Tensor(b!) y) -> ()",
-                tags=torch.Tag.pt2_compliant_tag,
-                lib=lib,
-            )
+        with torch.inference_mode():
+            self.test_auto_functionalize_extra2()
 
-            @torch.library.impl("mylib::foo", "cpu", lib=lib)
-            @torch._dynamo.disable
-            def foo_impl(x, y):
-                x.sin_()
-                y.sin_()
+    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
+    def test_inference_mode3_v2(self):
+        with torch.inference_mode():
+            self.test_auto_functionalize_extra3()
 
-            def f(x):
-                a = x[0]
-                b = x[1]
-                torch.ops.mylib.foo(a, b)
-                return
-
-            orig_args = [torch.randn(2)]
-
-            [aot_eager_args, result1, graph_aot] = self.run_aot_eager(f, orig_args)
-            [inductor_args, result2, graph_inductor] = self.run_inductor(f, orig_args)
-            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
-            result3 = f(*eager_args)
-
-            self.assertEqual(inductor_args, eager_args)
-            self.assertEqual(inductor_args, aot_eager_args)
-
-            self.assertEqual(result3, result1)
-            self.assertEqual(result3, result2)
-
-            if torch._dynamo.config.assume_static_by_default:
-                self.assertExpectedInline(
-                    graph_aot,
-                    """\
-def forward(self, arg0_1: "f32[2][1]cpu"):
-        select: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
-        select_1: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1)
-        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [select, select_1]);  select = select_1 = None
-        getitem_1: "f32[][]cpu" = auto_functionalized_v2[1]
-        getitem_2: "f32[][]cpu" = auto_functionalized_v2[2];  auto_functionalized_v2 = None
-        select_scatter: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(arg0_1, getitem_1, 0, 0);  getitem_1 = None
-        select_scatter_1: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(select_scatter, getitem_2, 0, 1);  select_scatter = getitem_2 = None
-        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, select_scatter_1);  arg0_1 = select_scatter_1 = copy_ = None
-        return ()""",  # noqa: B950
-                    ignore_comments=True,
-                    ignore_empty_lines=True,
-                )
-
-            # 2. Run with inductor backend
-
-            if torch._dynamo.config.assume_static_by_default:
-                self.assertExpectedInline(
-                    graph_inductor,
-                    """\
-def forward(self, arg0_1: "f32[2][1]cpu"):
-        select: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
-        select_1: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1)
-        as_strided_default: "f32[1][1]cpu" = torch.ops.aten.as_strided.default(select, [1], [1], 0);  select = None
-        clone_default: "f32[1][1]cpu" = torch.ops.aten.clone.default(as_strided_default);  as_strided_default = None
-        as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(clone_default, [], [], 0);  clone_default = None
-        as_strided_default_2: "f32[2][1]cpu" = torch.ops.aten.as_strided.default(select_1, [2], [1], 0);  select_1 = None
-        clone_default_1: "f32[2][1]cpu" = torch.ops.aten.clone.default(as_strided_default_2);  as_strided_default_2 = None
-        as_strided_default_3: "f32[][]cpu" = torch.ops.aten.as_strided.default(clone_default_1, [], [], 1);  clone_default_1 = None
-        foo_default = torch.ops.mylib.foo.default(as_strided_default_1, as_strided_default_3);  foo_default = None
-        select_scatter_default: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(arg0_1, as_strided_default_1, 0, 0);  as_strided_default_1 = None
-        select_scatter_default_1: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(select_scatter_default, as_strided_default_3, 0, 1);  select_scatter_default = as_strided_default_3 = None
-        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, select_scatter_default_1);  arg0_1 = select_scatter_default_1 = copy_ = None
-        return ()""",  # noqa: B950
-                    ignore_comments=True,
-                    ignore_empty_lines=True,
-                )
+    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
+    def test_inference_mode4_v2(self):
+        with torch.inference_mode():
+            self.test_auto_functionalize_extra4()
 
     @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
     def test_dynamic_v2(self):
diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py
index e4e15bc..bd2fcee 100644
--- a/tools/pyi/gen_pyi.py
+++ b/tools/pyi/gen_pyi.py
@@ -781,6 +781,9 @@
             "_is_functional_tensor": [
                 "def _is_functional_tensor(t: Tensor) -> _bool: ..."
             ],
+            "_is_functional_tensor_base": [
+                "def _is_functional_tensor_base(t: Tensor) -> _bool: ..."
+            ],
             "_from_functional_tensor": [
                 "def _from_functional_tensor(t: Tensor) -> Tensor: ..."
             ],
diff --git a/torch/_higher_order_ops/auto_functionalize.py b/torch/_higher_order_ops/auto_functionalize.py
index 8466d42..232981f 100644
--- a/torch/_higher_order_ops/auto_functionalize.py
+++ b/torch/_higher_order_ops/auto_functionalize.py
@@ -18,6 +18,13 @@
 )
 
 
+def get_base(tensor):
+    if torch.is_inference_mode_enabled():
+        return tensor._inference_mode_base
+    else:
+        return tensor._base
+
+
 @dataclass
 class ViewInfo:
     base_index: int
@@ -68,7 +75,7 @@
 
         if tensor is None:
             kwargs[f"{prefix}_base_index"] = None
-        elif tensor._base is None:
+        elif get_base(tensor) is None:
             # if the tensor is the base (not view), for simplicity we do not serialize view meta.
             kwargs[f"{prefix}_base_index"] = base_index
         else:
@@ -437,7 +444,7 @@
     arg_to_base_index: Dict[str, Any] = {}
 
     def update_dict(tensor, arg_name, index=None):
-        base = tensor if tensor._base is None else tensor._base
+        base = tensor if get_base(tensor) is None else get_base(tensor)
 
         def set_result(base_index):
             if index is None:
diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py
index e9cbb07..7bfa16a 100644
--- a/torch/_subclasses/functional_tensor.py
+++ b/torch/_subclasses/functional_tensor.py
@@ -1,6 +1,7 @@
 # mypy: allow-untyped-defs
 import contextlib
 import warnings
+import weakref
 from abc import ABC, abstractmethod
 from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, Union
 
@@ -111,7 +112,10 @@
         torch.ops.aten.unsafe_chunk.default,  # type: ignore[has-type]
     ]
 
-    def __new__(cls, elem):
+    # Used by auto_functionalize to determine the base of tensors during inference mode.
+    _inference_mode_base: Optional["FunctionalTensor"] = None
+
+    def __new__(cls, elem, mode):
         assert torch._is_functional_tensor(elem)
 
         # In general, we'd like our functional tensor subclass to only be in charge of functionalization,
@@ -142,9 +146,9 @@
             cls,
             elem.shape,  # sizes
             elem.stride() if not is_sparse_any(elem) else None,  # strides
-            elem.storage_offset()
-            if not is_sparse_any(elem)
-            else None,  # storage_offset
+            (
+                elem.storage_offset() if not is_sparse_any(elem) else None
+            ),  # storage_offset
             None,  # memory_format
             elem.dtype,  # dtype
             elem.layout,  # layout
@@ -158,6 +162,21 @@
         )
         torch._C._set_throw_on_mutable_data_ptr(out)
         out.elem = elem
+
+        if (
+            torch.is_inference_mode_enabled()
+            and torch._inductor.config.enable_auto_functionalized_v2
+        ):
+            if out.is_base_tensor():
+                out._inference_mode_base = None
+                # This assumes that the FunctionalTensor.elem does not change its storage after this point.
+                # Otherwise this would be invalid.
+                mode._storage_to_base[out.elem.untyped_storage()] = out
+            else:
+                out._inference_mode_base = mode._storage_to_base[
+                    out.elem.untyped_storage()
+                ]
+                assert out._inference_mode_base is not None
         return out
 
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
@@ -209,6 +228,7 @@
     @staticmethod
     def to_functional(x):
         # We will do the wrapping for the user.
+
         assert not torch._is_functional_tensor(x)
         # The only autograd metadata we care about on the FunctionalTensor is:
         # - requires_grad (so autograd runs)
@@ -226,7 +246,7 @@
 
         with functional_mode:
             torch._mirror_autograd_meta_to(x, x_functional)  # type: ignore[attr-defined]
-            out = FunctionalTensor(x_functional)
+            out = FunctionalTensor(x_functional, functional_mode)
             torch._mirror_autograd_meta_to(x_functional, out)  # type: ignore[attr-defined]
         return out
 
@@ -234,6 +254,9 @@
         torch._sync(self)
         return torch._from_functional_tensor(self.elem)
 
+    def is_base_tensor(self) -> bool:
+        return torch._is_functional_tensor_base(self.elem)
+
     def replace_(self, output) -> None:
         torch._functionalize_replace(self.elem, output)
 
@@ -316,6 +339,10 @@
         # discovery. This flag distinguishes between the two stages.
         self._allow_token_discovery = _allow_token_discovery
 
+        self._storage_to_base: weakref.WeakKeyDictionary[
+            torch.storage.UntypedStorage, Optional[FunctionalTensor]
+        ] = weakref.WeakKeyDictionary()
+
     # No-op if FunctionalTensorMode is already in use
     def __enter__(self):
         def _get_prev_mode():
@@ -366,6 +393,7 @@
             if not issubclass(t, torch._subclasses.FakeTensor)
             and t not in [torch.Tensor, FunctionalTensor]
         ]
+
         if unrecognized_types:
             not_implemented_log.debug(
                 "FunctionalTensor unrecognized subclass(es): %s", unrecognized_types
@@ -417,16 +445,13 @@
                 if r is not NotImplemented:
                     return r
 
-        def assert_is_functional(x):
-            assert torch._is_functional_tensor(x)
-
         def wrap(x):
             # Only wrap our outputs in subclasses if the inner functionalization call
             # also wrapped outputs into FunctionalTensorWrappers.
             # When can this happen? e.g. `torch.div(2, 2)`
             assert not isinstance(x, FunctionalTensor)
             if isinstance(x, torch.Tensor) and torch._is_functional_tensor(x):
-                return FunctionalTensor(x)
+                return FunctionalTensor(x, self)
             return x
 
         def unwrap(x):
diff --git a/torch/csrc/autograd/python_torch_functions_manual.cpp b/torch/csrc/autograd/python_torch_functions_manual.cpp
index 1feb1f4..92890a1 100644
--- a/torch/csrc/autograd/python_torch_functions_manual.cpp
+++ b/torch/csrc/autograd/python_torch_functions_manual.cpp
@@ -664,6 +664,10 @@
             !at::functionalization::impl::isFunctionalTensor(o));
         at::functionalization::impl::replace_(t, o);
       });
+  py_module.def("_is_functional_tensor_base", [](const at::Tensor& t) {
+    TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
+    return at::functionalization::impl::isBaseTensor(t);
+  });
   py_module.def("_functionalize_is_multi_output_view", [](const at::Tensor& t) {
     TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
     auto t_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);