[jit] Made a list for element-wise ops. (#59579)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/59579

Test Plan: Imported from OSS

Reviewed By: eellison

Differential Revision: D28955319

Pulled By: navahgar

fbshipit-source-id: 605531aedf9250a226b0401d55fda3427bdc6f33
diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
index 32983f9..2909428 100644
--- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp
+++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
@@ -164,21 +164,28 @@
       "aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor",
       "aten::where.ScalarOther(Tensor condition, Tensor self, Scalar other) -> Tensor",
       "aten::where.Scalar(Tensor condition, Scalar self, Scalar other) -> Tensor",
-      "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
       // TODO: enable other min/max variants, operators that can be both
       // elementwise or reductions:
       "aten::min.other(Tensor self, Tensor other) -> Tensor",
       "aten::max.other(Tensor self, Tensor other) -> Tensor",
       // TODO: enable slice, shape inference is not implemented for this op yet
-
-      "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
-      "aten::matmul(Tensor self, Tensor other) -> Tensor",
   };
   // clang-format on
 
   return supported_eltwise_set;
 }
 
+static const OperatorSet& supported_non_eltwise_set() {
+  // clang-format off
+  static const OperatorSet supported_non_eltwise_set{
+      "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
+      "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
+      "aten::matmul(Tensor self, Tensor other) -> Tensor",
+  };
+  // clang-format on
+  return supported_non_eltwise_set;
+};
+
 bool isSupported(Node* node) {
   // For Block codegen we allow limited ops.
   if (tensorexpr::getTEGenerateBlockCode()) {
@@ -198,6 +205,7 @@
   // clang-format on
 
   if (node->isMemberOf(supported_eltwise_set()) ||
+      node->isMemberOf(supported_non_eltwise_set()) ||
       node->isMemberOf(supported_misc_set) ||
       (texpr_reductions_enabled && node->isMemberOf(supported_reduction_set))) {
     // We only insert guards on Tensor types, so we rely on the output
@@ -527,9 +535,11 @@
         continue;
       }
 
-      // we only support shape calculations for elementwise and
+      // we only support shape calculations for elementwise, some
+      // non-elementwise like batch_norm, conv, matmul, and
       // a few exceptions (e.g. prim::ConstantChunk, etc) listed above
-      if (!n->isMemberOf(tensorexpr::supported_eltwise_set())) {
+      if (!n->isMemberOf(tensorexpr::supported_eltwise_set()) &&
+          !n->isMemberOf(tensorexpr::supported_non_eltwise_set())) {
         continue;
       }