[JIT] Shape inference improvement (#35051)
Summary:
Support `aten::div` in `PropagateCompleteShapeOnNode`.
Previously, complete shape propagation on `aten::div` was disabled because
shape inference relies on running the node to propagate shapes, and running
`aten::div` can hit an integer divide-by-zero.
However, all pointwise operations propagate shapes identically, so when
running the node we can swap the operation for `aten::div` with `aten::mul`.
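As a minimal standalone illustration (a sketch against the public ATen API, not the pass itself), the following assumes zero-filled integer stand-in tensors of the kind shape inference might execute a node on: running `mul` safely yields the broadcast output shape that `div` would also produce, while actually executing `div` on those inputs would trip integer divide-by-zero.
```cpp
// Sketch: why substituting mul for div is shape-safe. Assumes only ATen.
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Stand-in inputs such as shape inference might run a node on; integer
  // dtype plus zero values is the problematic combination for div.
  at::Tensor a = at::zeros({2, 3}, at::kLong);
  at::Tensor b = at::zeros({3}, at::kLong);

  // mul always executes safely and broadcasts to [2, 3], which is exactly
  // the output shape aten::div would have for the same inputs.
  at::Tensor out = at::mul(a, b);
  std::cout << out.sizes() << std::endl; // prints [2, 3]

  // By contrast, at::div(a, b) on zero-filled integer tensors raised an
  // integer divide-by-zero error at the time of this PR, so it cannot be
  // run blindly just to learn its output shape.
  return 0;
}
```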
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35051
Differential Revision: D20921359
Pulled By: eellison
fbshipit-source-id: 344371f34724a1b6bb2f853ebb4cef80423a4f9f
diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp
index 0d61ccd..3c8e083 100644
--- a/torch/csrc/jit/passes/shape_analysis.cpp
+++ b/torch/csrc/jit/passes/shape_analysis.cpp
@@ -397,10 +397,13 @@
return true;
}
- bool PropagateShapeOnNodeByRunningIt(Node* node) {
+ bool PropagateShapeOnNodeByRunningIt(Node* node, Operation op = nullptr) {
if (!canPropagateShapeByRunningIt(node))
return false;
- auto op = node->getOperation();
+
+ if (!op)
+ op = node->getOperation();
+
Stack stack;
for (auto input : node->inputs()) {
@@ -1829,12 +1832,21 @@
node->matches(
"aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
node->matches("aten::mul(Tensor self, Tensor other) -> Tensor")) {
- // These nodes and "div" handle tensors of different shapes internally,
- // so there's no need to insert explicit expand nodes. Note that "div" is
- // handled by the fallthrough because it's not always safe to run it due
- // to integer divide-by-zero.
+ // These nodes handle tensors of different shapes internally, so there's
+ // no need to insert explicit expand nodes.
return PropagateShapeOnNodeByRunningIt(node);
} else if (node->matches(
+ "aten::div(Tensor self, Tensor other) -> Tensor")) {
+ // "div" handle tensors of different shapes internally, so there's no need
+ // to insert explicit expand nodes.
+ // Note that this function could be merged to the one above , but "div" is
+ // not always safe to run by itself due to integer divide-by-zero.
+ // We fake the execution by running "mul" operation instead.
+ auto op = getOperatorForLiteral(
+ "aten::mul(Tensor self, Tensor other) -> Tensor")
+ ->getOperation();
+ return PropagateShapeOnNodeByRunningIt(node, op);
+ } else if (node->matches(
"aten::pow(Tensor self, Scalar exponent) -> Tensor")) {
node->output()->setType(tensor_types.at(0));
return true;
@@ -1843,6 +1855,7 @@
"aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor") ||
node->matches(
"aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor") ||
+ node->matches("aten::div(Tensor self, Scalar other) -> Tensor") ||
node->matches("aten::mul(Tensor self, Scalar other) -> Tensor")) {
auto first_scalar_type = (tensor_types)[0]->scalarType();
auto second_scalar_type =