Track base of FunctionalTensor in inference mode. (#135141)
The idea behind the tracking is the following: whenever we see a tensor, if it is a root tensor (i.e., it does not have any view metas), we consider it the base of all the tensors that share its storage.
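
As a rough illustration, here is a minimal Python sketch of the tracking scheme (not the PR's actual classes; the names storage_to_base and record_or_lookup_base are made up for this example):

    import weakref
    import torch

    # Map each untyped storage to the first root tensor (no view metas) seen for
    # that storage; every later tensor sharing that storage uses it as its base.
    storage_to_base = weakref.WeakKeyDictionary()  # UntypedStorage -> Tensor

    def record_or_lookup_base(t, is_root):
        # `is_root` stands in for the FunctionalTensorWrapper-level check that
        # the wrapper carries no view metas (see isBaseTensor in the diff below).
        storage = t.untyped_storage()
        if is_root:
            storage_to_base[storage] = t   # a root tensor is its own base
            return None
        return storage_to_base[storage]    # base of all tensors sharing the storage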
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135141
Approved by: https://github.com/zou3519
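
For reference, a hedged sketch of how the new torch._is_functional_tensor_base binding added below could be exercised; this is illustrative only and not part of the PR's tests:

    import torch

    x = torch.randn(4)
    fx = torch._to_functional_tensor(x)   # wrap in a C++ FunctionalTensorWrapper
    assert torch._is_functional_tensor(fx)
    # A freshly wrapped tensor carries no view metas, so it is its own base.
    assert torch._is_functional_tensor_base(fx)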
diff --git a/aten/src/ATen/FunctionalTensorWrapper.cpp b/aten/src/ATen/FunctionalTensorWrapper.cpp
index dfd4928..6f66e80 100644
--- a/aten/src/ATen/FunctionalTensorWrapper.cpp
+++ b/aten/src/ATen/FunctionalTensorWrapper.cpp
@@ -707,7 +707,12 @@
}
bool isFunctionalTensor(const at::Tensor& tensor) {
- return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Functionalize);
+ return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Functionalize);
+}
+
+bool isBaseTensor(const at::Tensor& tensor) {
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(tensor));
+ return unsafeGetFunctionalWrapper(tensor)->isBaseTensor();
}
bool isFunctionalTensor(const std::optional<Tensor>& t) {
diff --git a/aten/src/ATen/FunctionalTensorWrapper.h b/aten/src/ATen/FunctionalTensorWrapper.h
index afb3af5..ed5daf9 100644
--- a/aten/src/ATen/FunctionalTensorWrapper.h
+++ b/aten/src/ATen/FunctionalTensorWrapper.h
@@ -165,6 +165,12 @@
was_storage_changed_ = true;
}
+ // A FunctionalTensor is considered a base if it's not a view of another
+ // tensor.
+ bool isBaseTensor() const {
+ return view_metas_.empty();
+ }
+
c10::SymInt get_storage_size(bool before) {
return functional_storage_impl()->get_storage_size(before);
}
@@ -290,6 +296,8 @@
return functional_impl;
}
+TORCH_API bool isBaseTensor(const at::Tensor& tensor);
+
TORCH_API bool isFunctionalTensor(const at::Tensor& tensor);
TORCH_API bool isFunctionalTensor(const std::optional<Tensor>& t);
TORCH_API bool isFunctionalTensor(
diff --git a/test/inductor/test_auto_functionalize.py b/test/inductor/test_auto_functionalize.py
index b1e3ab1..019e88c 100644
--- a/test/inductor/test_auto_functionalize.py
+++ b/test/inductor/test_auto_functionalize.py
@@ -911,85 +911,20 @@
with torch.inference_mode():
self.test_auto_functionalize_extra1()
- # In inference mode we do not support inplacing views yet.
@torch._inductor.config.patch(enable_auto_functionalized_v2=True)
def test_inference_mode2_v2(self):
- with torch.inference_mode(), torch.library._scoped_library(
- "mylib", "FRAGMENT"
- ) as lib:
- torch.library.define(
- "mylib::foo",
- "(Tensor(a!) x, Tensor(b!) y) -> ()",
- tags=torch.Tag.pt2_compliant_tag,
- lib=lib,
- )
+ with torch.inference_mode():
+ self.test_auto_functionalize_extra2()
- @torch.library.impl("mylib::foo", "cpu", lib=lib)
- @torch._dynamo.disable
- def foo_impl(x, y):
- x.sin_()
- y.sin_()
+ @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
+ def test_inference_mode3_v2(self):
+ with torch.inference_mode():
+ self.test_auto_functionalize_extra3()
- def f(x):
- a = x[0]
- b = x[1]
- torch.ops.mylib.foo(a, b)
- return
-
- orig_args = [torch.randn(2)]
-
- [aot_eager_args, result1, graph_aot] = self.run_aot_eager(f, orig_args)
- [inductor_args, result2, graph_inductor] = self.run_inductor(f, orig_args)
- eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
- result3 = f(*eager_args)
-
- self.assertEqual(inductor_args, eager_args)
- self.assertEqual(inductor_args, aot_eager_args)
-
- self.assertEqual(result3, result1)
- self.assertEqual(result3, result2)
-
- if torch._dynamo.config.assume_static_by_default:
- self.assertExpectedInline(
- graph_aot,
- """\
-def forward(self, arg0_1: "f32[2][1]cpu"):
- select: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
- select_1: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1)
- auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [select, select_1]); select = select_1 = None
- getitem_1: "f32[][]cpu" = auto_functionalized_v2[1]
- getitem_2: "f32[][]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
- select_scatter: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(arg0_1, getitem_1, 0, 0); getitem_1 = None
- select_scatter_1: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(select_scatter, getitem_2, 0, 1); select_scatter = getitem_2 = None
- copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, select_scatter_1); arg0_1 = select_scatter_1 = copy_ = None
- return ()""", # noqa: B950
- ignore_comments=True,
- ignore_empty_lines=True,
- )
-
- # 2. Run with inductor backend
-
- if torch._dynamo.config.assume_static_by_default:
- self.assertExpectedInline(
- graph_inductor,
- """\
-def forward(self, arg0_1: "f32[2][1]cpu"):
- select: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
- select_1: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1)
- as_strided_default: "f32[1][1]cpu" = torch.ops.aten.as_strided.default(select, [1], [1], 0); select = None
- clone_default: "f32[1][1]cpu" = torch.ops.aten.clone.default(as_strided_default); as_strided_default = None
- as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(clone_default, [], [], 0); clone_default = None
- as_strided_default_2: "f32[2][1]cpu" = torch.ops.aten.as_strided.default(select_1, [2], [1], 0); select_1 = None
- clone_default_1: "f32[2][1]cpu" = torch.ops.aten.clone.default(as_strided_default_2); as_strided_default_2 = None
- as_strided_default_3: "f32[][]cpu" = torch.ops.aten.as_strided.default(clone_default_1, [], [], 1); clone_default_1 = None
- foo_default = torch.ops.mylib.foo.default(as_strided_default_1, as_strided_default_3); foo_default = None
- select_scatter_default: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(arg0_1, as_strided_default_1, 0, 0); as_strided_default_1 = None
- select_scatter_default_1: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(select_scatter_default, as_strided_default_3, 0, 1); select_scatter_default = as_strided_default_3 = None
- copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, select_scatter_default_1); arg0_1 = select_scatter_default_1 = copy_ = None
- return ()""", # noqa: B950
- ignore_comments=True,
- ignore_empty_lines=True,
- )
+ @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
+ def test_inference_mode4_v2(self):
+ with torch.inference_mode():
+ self.test_auto_functionalize_extra4()
@torch._inductor.config.patch(enable_auto_functionalized_v2=True)
def test_dynamic_v2(self):
diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py
index e4e15bc..bd2fcee 100644
--- a/tools/pyi/gen_pyi.py
+++ b/tools/pyi/gen_pyi.py
@@ -781,6 +781,9 @@
"_is_functional_tensor": [
"def _is_functional_tensor(t: Tensor) -> _bool: ..."
],
+ "_is_functional_tensor_base": [
+ "def _is_functional_tensor_base(t: Tensor) -> _bool: ..."
+ ],
"_from_functional_tensor": [
"def _from_functional_tensor(t: Tensor) -> Tensor: ..."
],
diff --git a/torch/_higher_order_ops/auto_functionalize.py b/torch/_higher_order_ops/auto_functionalize.py
index 8466d42..232981f 100644
--- a/torch/_higher_order_ops/auto_functionalize.py
+++ b/torch/_higher_order_ops/auto_functionalize.py
@@ -18,6 +18,13 @@
)
+def get_base(tensor):
+ if torch.is_inference_mode_enabled():
+ return tensor._inference_mode_base
+ else:
+ return tensor._base
+
+
@dataclass
class ViewInfo:
base_index: int
@@ -68,7 +75,7 @@
if tensor is None:
kwargs[f"{prefix}_base_index"] = None
- elif tensor._base is None:
+ elif get_base(tensor) is None:
# if the tensor is the base (not view), for simplicity we do not serialize view meta.
kwargs[f"{prefix}_base_index"] = base_index
else:
@@ -437,7 +444,7 @@
arg_to_base_index: Dict[str, Any] = {}
def update_dict(tensor, arg_name, index=None):
- base = tensor if tensor._base is None else tensor._base
+ base = tensor if get_base(tensor) is None else get_base(tensor)
def set_result(base_index):
if index is None:
diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py
index e9cbb07..7bfa16a 100644
--- a/torch/_subclasses/functional_tensor.py
+++ b/torch/_subclasses/functional_tensor.py
@@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
import contextlib
import warnings
+import weakref
from abc import ABC, abstractmethod
from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, Union
@@ -111,7 +112,10 @@
torch.ops.aten.unsafe_chunk.default, # type: ignore[has-type]
]
- def __new__(cls, elem):
+ # Used by auto_functionalize to determine the base of tensors during inference mode.
+ _inference_mode_base: Optional["FunctionalTensor"] = None
+
+ def __new__(cls, elem, mode):
assert torch._is_functional_tensor(elem)
# In general, we'd like our functional tensor subclass to only be in charge of functionalization,
@@ -142,9 +146,9 @@
cls,
elem.shape, # sizes
elem.stride() if not is_sparse_any(elem) else None, # strides
- elem.storage_offset()
- if not is_sparse_any(elem)
- else None, # storage_offset
+ (
+ elem.storage_offset() if not is_sparse_any(elem) else None
+ ), # storage_offset
None, # memory_format
elem.dtype, # dtype
elem.layout, # layout
@@ -158,6 +162,21 @@
)
torch._C._set_throw_on_mutable_data_ptr(out)
out.elem = elem
+
+ if (
+ torch.is_inference_mode_enabled()
+ and torch._inductor.config.enable_auto_functionalized_v2
+ ):
+ if out.is_base_tensor():
+ out._inference_mode_base = None
+ # This assumes that the FunctionalTensor.elem does not change its storage after this point.
+ # Otherwise this would be invalid.
+ mode._storage_to_base[out.elem.untyped_storage()] = out
+ else:
+ out._inference_mode_base = mode._storage_to_base[
+ out.elem.untyped_storage()
+ ]
+ assert out._inference_mode_base is not None
return out
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
@@ -209,6 +228,7 @@
@staticmethod
def to_functional(x):
# We will do the wrapping for the user.
+
assert not torch._is_functional_tensor(x)
# The only autograd metadata we care about on the FunctionalTensor is:
# - requires_grad (so autograd runs)
@@ -226,7 +246,7 @@
with functional_mode:
torch._mirror_autograd_meta_to(x, x_functional) # type: ignore[attr-defined]
- out = FunctionalTensor(x_functional)
+ out = FunctionalTensor(x_functional, functional_mode)
torch._mirror_autograd_meta_to(x_functional, out) # type: ignore[attr-defined]
return out
@@ -234,6 +254,9 @@
torch._sync(self)
return torch._from_functional_tensor(self.elem)
+ def is_base_tensor(self) -> bool:
+ return torch._is_functional_tensor_base(self.elem)
+
def replace_(self, output) -> None:
torch._functionalize_replace(self.elem, output)
@@ -316,6 +339,10 @@
# discovery. This flag distinguishes between the two stages.
self._allow_token_discovery = _allow_token_discovery
+ self._storage_to_base: weakref.WeakKeyDictionary[
+ torch.storage.UntypedStorage, Optional[FunctionalTensor]
+ ] = weakref.WeakKeyDictionary()
+
# No-op if FunctionalTensorMode is already in use
def __enter__(self):
def _get_prev_mode():
@@ -366,6 +393,7 @@
if not issubclass(t, torch._subclasses.FakeTensor)
and t not in [torch.Tensor, FunctionalTensor]
]
+
if unrecognized_types:
not_implemented_log.debug(
"FunctionalTensor unrecognized subclass(es): %s", unrecognized_types
@@ -417,16 +445,13 @@
if r is not NotImplemented:
return r
- def assert_is_functional(x):
- assert torch._is_functional_tensor(x)
-
def wrap(x):
# Only wrap our outputs in subclasses if the inner functionalization call
# also wrapped outputs into FunctionalTensorWrappers.
# When can this happen? e.g. `torch.div(2, 2)`
assert not isinstance(x, FunctionalTensor)
if isinstance(x, torch.Tensor) and torch._is_functional_tensor(x):
- return FunctionalTensor(x)
+ return FunctionalTensor(x, self)
return x
def unwrap(x):
diff --git a/torch/csrc/autograd/python_torch_functions_manual.cpp b/torch/csrc/autograd/python_torch_functions_manual.cpp
index 1feb1f4..92890a1 100644
--- a/torch/csrc/autograd/python_torch_functions_manual.cpp
+++ b/torch/csrc/autograd/python_torch_functions_manual.cpp
@@ -664,6 +664,10 @@
!at::functionalization::impl::isFunctionalTensor(o));
at::functionalization::impl::replace_(t, o);
});
+ py_module.def("_is_functional_tensor_base", [](const at::Tensor& t) {
+ TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
+ return at::functionalization::impl::isBaseTensor(t);
+ });
py_module.def("_functionalize_is_multi_output_view", [](const at::Tensor& t) {
TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
auto t_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);