[TensorExpr fuser] Guard nodes that have tensor output properties determined by non-tensor inputs (#44137)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44137

We only insert guards on Tensor types, so we rely on the output
type of a node being uniquely determined by its input types.
Bail if any non-Tensor input affects the output type and cannot
be reasoned about statically.
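For illustration, a minimal sketch of the failure mode (mirroring the new
test below): the Python type of z.item() depends on z's runtime dtype, so
the output dtype of the add is not determined by the Tensor types of x and
y alone, and a guard on those Tensor types cannot detect the change.

    import torch

    @torch.jit.script
    def scalar_type_input(x, y, z):
        # z.item() yields an int or a float depending on z's dtype,
        # which in turn changes the dtype of the result
        return x + y + 4 + z.item()

    x = torch.tensor([2, 2])
    scalar_type_input(x, x, torch.tensor(1))    # int scalar -> integer output
    scalar_type_input(x, x, torch.tensor(1.0))  # float scalar -> floating output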

Test Plan: Imported from OSS

Reviewed By: bertmaher

Differential Revision: D23543602

Pulled By: eellison

fbshipit-source-id: abd6fe0b1fd7fe6fc251694d4cd442b19c032dd7
diff --git a/test/jit/test_profiler.py b/test/jit/test_profiler.py
index 1e72c29..f9a38cc 100644
--- a/test/jit/test_profiler.py
+++ b/test/jit/test_profiler.py
@@ -23,6 +23,7 @@
         torch._C._jit_set_texpr_fuser_enabled(True)
         torch._C._jit_override_can_fuse_on_cpu(True)
         self.default_dtype = torch.get_default_dtype()
+        self.old_reduction_enabled = torch._C._jit_set_texpr_reductions_enabled(True)
         torch.set_default_dtype(torch.double)
 
 
@@ -33,6 +34,34 @@
         torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)
         torch._C._jit_override_can_fuse_on_cpu(self.can_fuse_on_cpu)
         torch.set_default_dtype(self.default_dtype)
+        torch._C._jit_set_texpr_reductions_enabled(self.old_reduction_enabled)
+
+    def test_tensor_type_not_determined_by_inputs(self):
+        @torch.jit.script
+        def scalar_type_input(x, y, z):
+            return x + y + 4 + z.item()
+
+        x = torch.tensor([2, 2])
+        scalar_type_input(x, x, torch.tensor(1))
+        scalar_type_input(x, x, torch.tensor(1))
+        scalar_type_input(x, x, torch.tensor(1.0))
+        g = torch.jit.last_executed_optimized_graph()
+
+        # item and add should not get pulled into the fusion group;
+        # we expect to see fusion group, (item / add), fusion group in the IR dump
+        FileCheck().check("TensorExpr").check("Scalar = aten::item").check_next("Tensor = aten::add").check("TensorExpr").run(g)
+
+
+        @torch.jit.script
+        def non_const_dtype(x, y, cond: bool):
+            dtype = torch.int16 if cond else torch.int32
+            return (x + y + 3).sum(dtype=dtype)
+
+        non_const_dtype(x, x, True)
+        non_const_dtype(x, x, True)
+        g = torch.jit.last_executed_optimized_graph()
+        # because dtype is non-constant, sum should not get pulled into the fusion group
+        FileCheck().check("TensorExpr").check("TensorExpr").check_not("aten::sum").run(g)
 
     def test_specialize_backward(self):
         def test_fuse(a, b):
diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
index 030d4b2..7290a28 100644
--- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp
+++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
@@ -136,10 +136,29 @@
   };
   // clang-format on
 
-  if (node->isMemberOf(supported_operator_set)) {
-    return true;
-  }
-  if (texpr_reductions_enabled && node->isMemberOf(supported_reduction_set)) {
+  if (node->isMemberOf(supported_operator_set) ||
+      (texpr_reductions_enabled && node->isMemberOf(supported_reduction_set))) {
+    // We only insert guards on Tensor types, so we rely on the output
+    // type of a node being uniquely determined by its input types.
+    // Bail if any non-Tensor input affects the output type and cannot
+    // be reasoned about statically.
+
+    // The Value is a Scalar (an int or a float), as produced e.g. by .item()
+    for (Value* v : node->inputs()) {
+      if (v->type()->cast<NumberType>()) {
+        return false;
+      }
+    }
+
+    // Bail if a dtype or device argument is not a compile-time constant
+    for (auto arg_name : {"dtype", "device"}) {
+      if (auto index = node->schema().argumentIndexWithName(arg_name)) {
+        if (!toIValue(node->input(*index))) {
+          return false;
+        }
+      }
+    }
+
     return true;
   }
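For contrast, a hedged sketch (not part of this change) of the constant-dtype
case: when dtype is a prim::Constant, toIValue() succeeds on that input, the
output type is statically determined, and the reduction should remain eligible
for fusion (assuming reductions are enabled via
torch._C._jit_set_texpr_reductions_enabled(True), as in the test setUp).

    @torch.jit.script
    def const_dtype(x, y):
        # dtype is a compile-time constant here, so the output type of sum
        # is uniquely determined by the input Tensor types
        return (x + y + 3).sum(dtype=torch.int16)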