fix set_() with functionalization (#90722)
This should fix https://github.com/pytorch/pytorch/issues/90573
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90722
Approved by: https://github.com/ezyang
diff --git a/aten/src/ATen/templates/RegisterFunctionalization.cpp b/aten/src/ATen/templates/RegisterFunctionalization.cpp
index 7160856..d6cef51 100644
--- a/aten/src/ATen/templates/RegisterFunctionalization.cpp
+++ b/aten/src/ATen/templates/RegisterFunctionalization.cpp
@@ -2,6 +2,7 @@
// ${generated_comment}
#include <ATen/core/LegacyTypeDispatch.h>
+#include <ATen/EmptyTensor.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/FunctionalInverses.h>
#include <torch/library.h>
diff --git a/test/test_functionalization.py b/test/test_functionalization.py
index ec1a0ca..fd60de1 100644
--- a/test/test_functionalization.py
+++ b/test/test_functionalization.py
@@ -146,6 +146,17 @@
r = _functionalize(f, reapply_views=True, crossref=self.crossref)(torch.ones(2, 2))
self.assertEqual(r.stride(), (5, 1))
+ def test_set_(self):
+ def f(x):
+ y = torch.ones(2)
+ y.set_(x.storage())
+ return y
+
+ # We should probaby get the crossref test to work,
+ # but fixing it for Storage() objects is annoying.
+ r = _functionalize(f, reapply_views=True, crossref=False)(torch.ones(2))
+ self.assertEqual(str(r.device), 'cpu')
+
def test_view_clone_view_inplace(self):
def f(input):
shape = [1, 1024, 128, 128]
diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py
index 39b4962..558a6c9 100644
--- a/torchgen/gen_functionalization_type.py
+++ b/torchgen/gen_functionalization_type.py
@@ -596,10 +596,18 @@
)
meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
+ # We don't want to run the inplace meta func for ops like .set_(), because:
+ # (1) they're unnecessary: inplace meta checks are only useful for ops like add_(),
+ # where broadcasting will work for the out-of-place case but should fail on the inplace call
+ # (2) They'll also fail without adding extra infra: we'd need to convert the input storage argument
+ # into a meta storage
+ any_storage_args = any(
+ a.type == BaseType(BaseTy.Storage) for a in f.func.arguments.flat_all
+ )
return f"""
{dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
- if ({str(f.func.kind() == SchemaKind.inplace).lower()}) {{
+ if ({str(not any_storage_args and f.func.kind() == SchemaKind.inplace).lower()}) {{
// Before converting the mutable op to its functional variant, run meta tensors through the original op.
// This will help us catch shape errors that apply to inplace ops that wouldn't apply to their functional variants.
// (We can only do this for inplace ops today though, because they technicaly all support meta tensors).