[TensorExpr] Run constant pooling in fusion groups to dedupe constants. (#47402)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47402
Test Plan: Imported from OSS
Reviewed By: eellison
Differential Revision: D24740957
Pulled By: ZolotukhinM
fbshipit-source-id: 741cbddc4bf2decd95d444235c424a4ae003d0de
diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
index 793e755..86d5652 100644
--- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp
+++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
@@ -4,6 +4,7 @@
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
+#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/pass_manager.h>
#include <torch/csrc/jit/passes/remove_redundant_profiles.h>
@@ -607,6 +608,8 @@
SubgraphUtils::unmergeSubgraph(n);
return true;
}
+ // Cleanup the subgraph from duplicated constants while we're at it.
+ ConstantPooling(subgraph);
return false;
}