[TensorExpr] Run constant pooling in fusion groups to dedupe constants. (#47402)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47402

Test Plan: Imported from OSS

Reviewed By: eellison

Differential Revision: D24740957

Pulled By: ZolotukhinM

fbshipit-source-id: 741cbddc4bf2decd95d444235c424a4ae003d0de
diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
index 793e755..86d5652 100644
--- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp
+++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
@@ -4,6 +4,7 @@
 #include <torch/csrc/jit/ir/alias_analysis.h>
 #include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
+#include <torch/csrc/jit/passes/constant_pooling.h>
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/passes/pass_manager.h>
 #include <torch/csrc/jit/passes/remove_redundant_profiles.h>
@@ -607,6 +608,8 @@
       SubgraphUtils::unmergeSubgraph(n);
       return true;
     }
+    // Cleanup the subgraph from duplicated constants while we're at it.
+    ConstantPooling(subgraph);
     return false;
   }