Support onnxifi with partially shaped inferred net (#16877)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16877
That's it.
Reviewed By: ipiszy
Differential Revision: D13997771
fbshipit-source-id: f512c7f30b4a4747aca335a0769712c2a2cc2206
diff --git a/caffe2/opt/onnxifi_transformer.cc b/caffe2/opt/onnxifi_transformer.cc
index 7b8ee34..607d6f0 100644
--- a/caffe2/opt/onnxifi_transformer.cc
+++ b/caffe2/opt/onnxifi_transformer.cc
@@ -792,14 +792,22 @@
auto* shape_arg = net.add_arg();
shape_arg->set_name("input_shape_info");
for (const auto& i : op.input()) {
+ const auto it = shape_hints.find(i);
+ if (it == shape_hints.end()) {
+ return false;
+ }
shape_arg->mutable_tensors()->Add()->CopyFrom(
- WrapShapeInfoIntoTensorProto(i, shape_hints.at(i)));
+ WrapShapeInfoIntoTensorProto(i, it->second));
}
shape_arg = net.add_arg();
shape_arg->set_name("output_shape_info");
for (const auto& i : op.output()) {
+ const auto it = shape_hints.find(i);
+ if (it == shape_hints.end()) {
+ return false;
+ }
shape_arg->mutable_tensors()->Add()->CopyFrom(
- WrapShapeInfoIntoTensorProto(i, shape_hints.at(i)));
+ WrapShapeInfoIntoTensorProto(i, it->second));
}
std::string c2_model_str;