[ONNX] Improve Expand shape inference (#69264)
Extend shape inference support for `Expand`, when value of argument `shape` is unknown. Infer the rank of the output of `Expand`, and set shape to dynamic, if shape of argument `shape` is known.
Without this, shape inference aborts, and falls back to the static shape provided by tracer, which is incorrect in many cases.
Co-authored-by: BowenBao <bowbaomicrosoft.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72985
diff --git a/test/onnx/test_pytorch_onnx_shape_inference.py b/test/onnx/test_pytorch_onnx_shape_inference.py
index 9319670..3808de1 100644
--- a/test/onnx/test_pytorch_onnx_shape_inference.py
+++ b/test/onnx/test_pytorch_onnx_shape_inference.py
@@ -114,5 +114,14 @@
slice = g.op("Slice", input, start_input, end, axis, step)
self.run_test(g, slice.node(), expect_tensor(None, shape=(None, None)))
+ def test_expand(self):
+ g = self.create_empty_graph()
+ input = g.addInput()
+ constant = self.insert_tensor_constant(g, torch.ones(2, 4))
+ input.setType(constant.type().with_sizes([None, None]))
+ shape = g.op("Shape", input)
+ expand = g.op("Expand", constant, shape)
+ self.run_test(g, expand.node(), expect_tensor("Float", shape=(None, None)))
+
if __name__ == '__main__':
unittest.main()
diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp
index 7219d2d..167e401 100644
--- a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp
+++ b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp
@@ -1374,6 +1374,8 @@
if (input0_shape_size.has_value()) {
auto input0_shape_value = input0_shape_size.value();
if (ConstantValueMap::HasValue(n->input(1)->debugName())) {
+ // When value of `shape` is statically known,
+ // output shape can be computed.
auto shape_temp = ConstantValueMap::GetValueInto1DInt64Vector(
n->input(1)->debugName());
auto final_shape =
@@ -1381,6 +1383,23 @@
if (final_shape.has_value()) {
UpdateShape(n->output(), final_shape.value());
}
+ } else if (
+ auto expand_shape =
+ ConstantValueMap::GetShapeInto1DInt64VectorWithOneUnknown(
+ n->input(1)->debugName())) {
+ // When shape of `shape` is statically known,
+ // output rank can be computed.
+ TORCH_INTERNAL_ASSERT(
+ expand_shape.value().size() == 1,
+ "`Shape` input to `Expand` should be a 1-D tensor. Instead got rank ",
+ expand_shape.value().size());
+ if (expand_shape.value()[0] > 0) {
+ std::vector<c10::ShapeSymbol> final_shape;
+ for (const auto i : c10::irange(expand_shape.value()[0])) {
+ final_shape.emplace_back(c10::ShapeSymbol::newSymbol());
+ }
+ UpdateShape(n->output(), c10::SymbolicShape(final_shape));
+ }
}
}
}