fix set_() with functionalization (#90722)

This should fix https://github.com/pytorch/pytorch/issues/90573

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90722
Approved by: https://github.com/ezyang
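For context, here is a minimal sketch (not part of this PR) of the pattern the fix is meant to support, mirroring the test added below but going through the public `torch.func.functionalize` wrapper; the wrapper choice and the `remove='mutations'` flag are my own assumptions and should roughly correspond to the test's internal `_functionalize(..., reapply_views=True)` helper on a sufficiently recent build:

```python
import torch
from torch.func import functionalize

def f(x):
    y = torch.ones(2)
    # set_() with a Storage argument used to fail under functionalization (#90573)
    y.set_(x.storage())
    return y

# remove='mutations' keeps views as views, roughly matching reapply_views=True
out = functionalize(f, remove='mutations')(torch.ones(2))
print(out.device)  # expected: cpu
```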
diff --git a/aten/src/ATen/templates/RegisterFunctionalization.cpp b/aten/src/ATen/templates/RegisterFunctionalization.cpp
index 7160856..d6cef51 100644
--- a/aten/src/ATen/templates/RegisterFunctionalization.cpp
+++ b/aten/src/ATen/templates/RegisterFunctionalization.cpp
@@ -2,6 +2,7 @@
 // ${generated_comment}
 
 #include <ATen/core/LegacyTypeDispatch.h>
+#include <ATen/EmptyTensor.h>
 #include <ATen/FunctionalTensorWrapper.h>
 #include <ATen/FunctionalInverses.h>
 #include <torch/library.h>
diff --git a/test/test_functionalization.py b/test/test_functionalization.py
index ec1a0ca..fd60de1 100644
--- a/test/test_functionalization.py
+++ b/test/test_functionalization.py
@@ -146,6 +146,17 @@
         r = _functionalize(f, reapply_views=True, crossref=self.crossref)(torch.ones(2, 2))
         self.assertEqual(r.stride(), (5, 1))
 
+    def test_set_(self):
+        def f(x):
+            y = torch.ones(2)
+            y.set_(x.storage())
+            return y
+
+        # We should probably get the crossref test to work,
+        # but fixing it for Storage() objects is annoying.
+        r = _functionalize(f, reapply_views=True, crossref=False)(torch.ones(2))
+        self.assertEqual(str(r.device), 'cpu')
+
     def test_view_clone_view_inplace(self):
         def f(input):
             shape = [1, 1024, 128, 128]
diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py
index 39b4962..558a6c9 100644
--- a/torchgen/gen_functionalization_type.py
+++ b/torchgen/gen_functionalization_type.py
@@ -596,10 +596,18 @@
         )
 
     meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
+    # We don't want to run the inplace meta func for ops like .set_(), because:
+    # (1) they're unnecessary: inplace meta checks are only useful for ops like add_(),
+    #     where broadcasting will work for the out-of-place case but should fail on the inplace call
+    # (2) they'll also fail without adding extra infra: we'd need to convert the input storage argument
+    #     into a meta storage
+    any_storage_args = any(
+        a.type == BaseType(BaseTy.Storage) for a in f.func.arguments.flat_all
+    )
 
     return f"""
     {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
-      if ({str(f.func.kind() == SchemaKind.inplace).lower()}) {{
+      if ({str(not any_storage_args and f.func.kind() == SchemaKind.inplace).lower()}) {{
         // Before converting the mutable op to its functional variant, run meta tensors through the original op.
         // This will help us catch shape errors that apply to inplace ops that wouldn't apply to their functional variants.
         // (We can only do this for inplace ops today though, because they technically all support meta tensors).
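
For reference, a small sketch of what the new `any_storage_args` predicate evaluates to for a Storage-taking op versus an ordinary inplace op. This is illustration only: the schema strings are hand-copied approximations of the entries in `native_functions.yaml`, and it assumes a PyTorch source checkout where `torchgen` is importable:

```python
from torchgen.model import BaseTy, BaseType, FunctionSchema

schemas = [
    # set_() takes a Storage, so the generated wrapper should skip the inplace meta run
    "set_.source_Storage(Tensor(a!) self, Storage source) -> Tensor(a!)",
    # add_() takes no Storage, so the meta shape check is still emitted
    "add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)",
]
for schema_str in schemas:
    func = FunctionSchema.parse(schema_str)
    any_storage_args = any(
        a.type == BaseType(BaseTy.Storage) for a in func.arguments.flat_all
    )
    print(f"{func.name}: any_storage_args={any_storage_args}")
```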