[jit] Made a list for element-wise ops. (#59579)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/59579
Test Plan: Imported from OSS
Reviewed By: eellison
Differential Revision: D28955319
Pulled By: navahgar
fbshipit-source-id: 605531aedf9250a226b0401d55fda3427bdc6f33
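The gist of the change, before the diff: batch_norm, conv2d, and matmul move out of supported_eltwise_set into a new supported_non_eltwise_set, and both call sites (the fusability check in isSupported and the shape-calculation gate) now consult the two sets. Below is a minimal standalone sketch of the pattern, using std::unordered_set<std::string> as a stand-in for torch::jit::OperatorSet; the names OpSet, eltwiseOps, nonEltwiseOps, and supportsShapeCalc are illustrative only, not the real fuser API.

    #include <string>
    #include <unordered_set>

    // Stand-in for torch::jit::OperatorSet in this sketch.
    using OpSet = std::unordered_set<std::string>;

    // Each category of supported ops lives behind its own
    // function-local static set, built once on first use.
    static const OpSet& eltwiseOps() {
      static const OpSet ops{"aten::min.other", "aten::max.other"};
      return ops;
    }

    static const OpSet& nonEltwiseOps() {
      static const OpSet ops{
          "aten::batch_norm", "aten::conv2d", "aten::matmul"};
      return ops;
    }

    // Shape calculations are supported iff the op is in either set,
    // mirroring the combined check in the second hunk below.
    static bool supportsShapeCalc(const std::string& op) {
      return eltwiseOps().count(op) > 0 || nonEltwiseOps().count(op) > 0;
    }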
diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
index 32983f9..2909428 100644
--- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp
+++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
@@ -164,21 +164,28 @@
"aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor",
"aten::where.ScalarOther(Tensor condition, Tensor self, Scalar other) -> Tensor",
"aten::where.Scalar(Tensor condition, Scalar self, Scalar other) -> Tensor",
- "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
// TODO: enable other min/max variants, operators that can act as
// either elementwise ops or reductions:
"aten::min.other(Tensor self, Tensor other) -> Tensor",
"aten::max.other(Tensor self, Tensor other) -> Tensor",
// TODO: enable slice, shape inference is not implemented for this op yet
-
- "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
- "aten::matmul(Tensor self, Tensor other) -> Tensor",
};
// clang-format on
return supported_eltwise_set;
}
+static const OperatorSet& supported_non_eltwise_set() {
+ // clang-format off
+ static const OperatorSet supported_non_eltwise_set{
+ "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
+ "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
+ "aten::matmul(Tensor self, Tensor other) -> Tensor",
+ };
+ // clang-format on
+ return supported_non_eltwise_set;
+}
+
bool isSupported(Node* node) {
// For Block codegen we allow limited ops.
if (tensorexpr::getTEGenerateBlockCode()) {
@@ -198,6 +205,7 @@
// clang-format on
if (node->isMemberOf(supported_eltwise_set()) ||
+ node->isMemberOf(supported_non_eltwise_set()) ||
node->isMemberOf(supported_misc_set) ||
(texpr_reductions_enabled && node->isMemberOf(supported_reduction_set))) {
// We only insert guards on Tensor types, so we rely on the output
@@ -527,9 +535,11 @@
continue;
}
- // we only support shape calculations for elementwise and
+  // we only support shape calculations for elementwise ops, some
+  // non-elementwise ops like batch_norm, conv2d, and matmul, and
// a few exceptions (e.g. prim::ConstantChunk, etc) listed above
- if (!n->isMemberOf(tensorexpr::supported_eltwise_set())) {
+ if (!n->isMemberOf(tensorexpr::supported_eltwise_set()) &&
+ !n->isMemberOf(tensorexpr::supported_non_eltwise_set())) {
continue;
}