[nnc] Do not fuse unsqueeze with variable dim (#58346)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58346
If `dim` is a variable, NNC doesn't know how to translate the node, since the
output shape is unknown. The issue manifested as a `bad_variant_access` when we
tried to pull an int constant out of that argument.
Note that while the profiling executor (PE) will pick up the resultant shape, it
won't set guards accordingly, so we can't rely on it; instead, simply refuse to
fuse `aten::unsqueeze` when `dim` is not a constant.
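For illustration, a minimal sketch of the failing pattern (assumes the TE fuser is force-enabled on CPU; the new fuser test below is the actual coverage):

    import torch

    # Sketch only: force-enable TensorExpr fusion on CPU so the pattern
    # reaches NNC in a CPU-only run.
    torch._C._jit_set_texpr_fuser_enabled(True)
    torch._C._jit_override_can_fuse_on_cpu(True)

    @torch.jit.script
    def f(x, y, z: int):
        # `z` is a runtime argument, so unsqueeze's `dim` input is not a
        # prim::Constant in the scripted graph.
        return x * torch.unsqueeze(y, dim=z)

    x = torch.rand(4, 4, 64).permute(1, 0, 2)
    y = torch.rand(4, 4)
    for _ in range(3):  # let the profiling executor specialize and attempt fusion
        f(x, y, 2)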
ghstack-source-id: 129078971
Test Plan: new fuser test
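For example (exact invocation is a sketch and may differ by environment):

    python test/test_jit_fuser_te.py -k test_unsqueeze_var_dim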
Reviewed By: navahgar
Differential Revision: D28460956
fbshipit-source-id: 57ef918ef309ee57bfdf86717b910b6549750454
diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py
index 99e0f28..a1b42bf 100644
--- a/test/test_jit_fuser_te.py
+++ b/test/test_jit_fuser_te.py
@@ -1795,5 +1795,13 @@
         script = self.checkScript(eager, (x, y))
         self.assertAllFused(script.graph_for(x, y))
 
+    def test_unsqueeze_var_dim(self):
+        def eager(x, y, z: int):
+            return x * torch.unsqueeze(y, dim=z)
+        x = torch.rand(4, 4, 64).permute(1, 0, 2)
+        y = torch.rand(4, 4)
+        z = 2
+        script = self.checkScript(eager, (x, y, z))
+
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
index 349b21f..6c4a8bb 100644
--- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp
+++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
@@ -1029,6 +1029,13 @@
       }
     }
+    if (node->kind() == aten::unsqueeze) {
+      // `dim` argument must be a constant.
+      if (node->input(1)->node()->kind() != prim::Constant) {
+        return false;
+      }
+    }
+
     if (node->kind() == aten::conv2d) {
       if (!tensorexpr::conv2dIsSupportedJit(node)) {
         GRAPH_DEBUG("Params of conv2d are not supported");