[TensorExpr fuser] Guard nodes that have tensor output properties determined by non-tensor inputs (#44137)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44137
We only insert guards on Tensor types, so we rely on the output of a
node being uniquely determined by its input types. Bail if any
non-Tensor input affects the output type and cannot be reasoned about
statically.
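
As a concrete illustration, here is a minimal sketch of the failure
mode (it mirrors the new test added below; the exact result dtypes
follow from standard type promotion):

    import torch

    @torch.jit.script
    def scalar_type_input(x, y, z):
        # z.item() yields a Scalar (NumberType); whether it is an int
        # or a float at runtime changes the promoted output dtype,
        # which guards on the Tensor inputs alone cannot capture.
        return x + y + 4 + z.item()

    x = torch.tensor([2, 2])
    print(scalar_type_input(x, x, torch.tensor(1)).dtype)    # integer dtype
    print(scalar_type_input(x, x, torch.tensor(1.0)).dtype)  # floating dtype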
Test Plan: Imported from OSS
Reviewed By: bertmaher
Differential Revision: D23543602
Pulled By: eellison
fbshipit-source-id: abd6fe0b1fd7fe6fc251694d4cd442b19c032dd7
diff --git a/test/jit/test_profiler.py b/test/jit/test_profiler.py
index 1e72c29..f9a38cc 100644
--- a/test/jit/test_profiler.py
+++ b/test/jit/test_profiler.py
@@ -23,6 +23,7 @@
torch._C._jit_set_texpr_fuser_enabled(True)
torch._C._jit_override_can_fuse_on_cpu(True)
self.default_dtype = torch.get_default_dtype()
+ self.old_reduction_enabled = torch._C._jit_set_texpr_reductions_enabled(True)
torch.set_default_dtype(torch.double)
@@ -33,6 +34,34 @@
torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)
torch._C._jit_override_can_fuse_on_cpu(self.can_fuse_on_cpu)
torch.set_default_dtype(self.default_dtype)
+ torch._C._jit_set_texpr_reductions_enabled(self.old_reduction_enabled)
+
+ def test_tensor_type_not_determined_by_inputs(self):
+ @torch.jit.script
+ def scalar_type_input(x, y, z):
+ return x + y + 4 + z.item()
+
+ x = torch.tensor([2, 2])
+ scalar_type_input(x, x, torch.tensor(1))
+ scalar_type_input(x, x, torch.tensor(1))
+ scalar_type_input(x, x, torch.tensor(1.0))
+ g = torch.jit.last_executed_optimized_graph()
+
+ # item & add should not get pulled into the fusion group;
+ # we expect to see FusionGroup, (item / add), FusionGroup in the IR dump
+ FileCheck().check("TensorExpr").check("Scalar = aten::item").check_next("Tensor = aten::add").check("TensorExpr").run(g)
+
+
+ @torch.jit.script
+ def non_const_dtype(x, y, cond: bool):
+ dtype = torch.int16 if cond else torch.int32
+ return (x + y + 3).sum(dtype=dtype)
+
+ non_const_dtype(x, x, True)
+ non_const_dtype(x, x, True)
+ g = torch.jit.last_executed_optimized_graph()
+ # because dtype is non-const, sum should not get pulled into the Fusion Group
+ FileCheck().check("TensorExpr").check("TensorExpr").check_not("aten::sum").run(g)
def test_specialize_backward(self):
def test_fuse(a, b):
diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
index 030d4b2..7290a28 100644
--- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp
+++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
@@ -136,10 +136,29 @@
};
// clang-format on
- if (node->isMemberOf(supported_operator_set)) {
- return true;
- }
- if (texpr_reductions_enabled && node->isMemberOf(supported_reduction_set)) {
+ if (node->isMemberOf(supported_operator_set) ||
+ (texpr_reductions_enabled && node->isMemberOf(supported_reduction_set))) {
+ // We only insert guards on Tensor types, so we rely on the output
+ // of a node being uniquely determined by its input types.
+ // Bail if any non-Tensor input affects the output type
+ // and cannot be reasoned about statically.
+
+ // A NumberType input may be an int or a float at runtime
+ // (e.g. the result of .item()), which a Tensor guard cannot capture
+ for (Value* v : node->inputs()) {
+ if (v->type()->cast<NumberType>()) {
+ return false;
+ }
+ }
+
+ // Bail if the dtype / device argument is not a compile-time constant
+ for (auto arg_name : {"dtype", "device"}) {
+ if (auto index = node->schema().argumentIndexWithName(arg_name)) {
+ if (!toIValue(node->input(*index))) {
+ return false;
+ }
+ }
+ }
+
return true;
}
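
For contrast, a hypothetical variant (not part of this PR's tests) in
which the dtype argument is a compile-time constant: toIValue() on the
"dtype" input then succeeds, so sum remains eligible for fusion under
this guard (assuming texpr reductions are enabled, as in the test setUp
above):

    import torch

    torch._C._jit_set_texpr_reductions_enabled(True)

    @torch.jit.script
    def const_dtype(x, y):
        # dtype is a constant in the graph, so the guard's toIValue()
        # check on the "dtype" argument succeeds.
        return (x + y + 3).sum(dtype=torch.int16)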