Add Shape Inference for Reshape Operator
Summary: Add shape inference for reshape. Because it cannot do shape inference for reshaped tensor with runtime tensor data, set `out[0].set_unknown_shape(true)` if no `shape` argument is used.
Differential Revision: D4671125
fbshipit-source-id: 685a9198f9b08e3336014c792f20051b381d8619
diff --git a/caffe2/operators/reshape_op.cc b/caffe2/operators/reshape_op.cc
index 974e73e..0837495 100644
--- a/caffe2/operators/reshape_op.cc
+++ b/caffe2/operators/reshape_op.cc
@@ -8,6 +8,101 @@
OPERATOR_SCHEMA(Reshape)
.NumInputs(1, 2)
.NumOutputs(2)
+ .TensorInferenceFunction(
+ [](const OperatorDef& def, const vector<TensorShape>& in) {
+ vector<TensorShape> out(2);
+
+ // Do shape inference for old_shape
+ out[1].set_data_type(TensorProto::INT64);
+ out[1].add_dims(in[0].dims_size());
+
+ ArgumentHelper helper(def);
+ if (!helper.HasArgument("shape")) {
+ // Cannot do shape inference for reshaped tensor from runtime data.
+ CAFFE_ENFORCE_EQ(
+ in.size(),
+ 2,
+ "New shape must be specified by either the input blob or the "
+ "argument `shape`.");
+ out[0].set_unknown_shape(true);
+ return out;
+ }
+ CAFFE_ENFORCE_EQ(
+ in.size(),
+ 1,
+ "New shape must not be specified by the input blob and the "
+ "argument `shape` at the same time.");
+
+ // Infer the actual new shape
+ auto actualNewShape = helper.GetRepeatedArgument<int64_t>("shape");
+
+ // Copy over the dimensions for those that are specified zero
+ // and check the eligibility of input
+ for (int i = 0; i < actualNewShape.size(); ++i) {
+ CAFFE_ENFORCE_GE(
+ actualNewShape[i],
+ -1,
+ "The dimensions in argument `shape` "
+ "must not be a negative number.");
+
+ if (actualNewShape[i] == 0) {
+ CAFFE_ENFORCE_LT(
+ i,
+ in[0].dims_size(),
+ "Argument `shape` has a dimension set to zero that exceeds "
+ "the original dimension size.");
+ actualNewShape[i] = in[0].dims(i);
+ }
+ }
+
+ // Check if the new shape is valid and fills in the missing dimension
+ // specified by -1.
+ int64_t totalSize = 1;
+ for (const auto d : in[0].dims()) {
+ totalSize *= d;
+ }
+ int64_t size = 1;
+ int unknownIdx = -1;
+ for (int i = 0; i < actualNewShape.size(); ++i) {
+ const auto dim = actualNewShape[i];
+ if (dim == -1) {
+ CAFFE_ENFORCE(
+ unknownIdx == -1,
+ "Argument `shape` has more than one missing dimension.");
+ unknownIdx = i;
+ } else {
+ size *= dim;
+ }
+ }
+
+ if (unknownIdx != -1) {
+ CAFFE_ENFORCE(
+ totalSize % size == 0,
+ "Argument `shape` does not agree with the input data.",
+ " (",
+ totalSize,
+ " vs ",
+ size,
+ ")");
+ actualNewShape[unknownIdx] = totalSize / size;
+ } else {
+ CAFFE_ENFORCE_EQ(
+ totalSize,
+ size,
+ "Argument `shape` does not agree with the input data.",
+ " (",
+ totalSize,
+ " != ",
+ size,
+ ")");
+ }
+
+ out[0].set_data_type(in[0].data_type());
+ for (const auto d : actualNewShape) {
+ out[0].add_dims(d);
+ }
+ return out;
+ })
.AllowInplace({{0, 0}})
.SetDoc(R"DOC(
Reshape the input tensor similar to numpy.reshape.
diff --git a/caffe2/operators/utility_ops.cc b/caffe2/operators/utility_ops.cc
index 1a080de..9080bcc 100644
--- a/caffe2/operators/utility_ops.cc
+++ b/caffe2/operators/utility_ops.cc
@@ -113,6 +113,7 @@
}
total *= d;
}
+ out[0].set_data_type(in[0].data_type());
out[0].add_dims(in[0].dims(0));
out[0].add_dims(total);
return out;
@@ -139,6 +140,7 @@
for (auto d : in[0].dims()) {
total *= d;
}
+ out[0].set_data_type(in[0].data_type());
out[0].add_dims(total);
return out;
})
@@ -182,7 +184,7 @@
return out;
})
.SetDoc(R"DOC(
-Produces tensor condaining data of first input and shape of second input.
+Produces tensor containing data of first input and shape of second input.
)DOC")
.Input(0, "data", "Tensor whose data will be copied into the output.")
.Input(1, "shape_tensor", "Tensor whose shape will be applied to output.")
diff --git a/caffe2/python/operator_test/shape_inference_test.py b/caffe2/python/operator_test/shape_inference_test.py
index 2fe91d1..23869d5 100644
--- a/caffe2/python/operator_test/shape_inference_test.py
+++ b/caffe2/python/operator_test/shape_inference_test.py
@@ -279,6 +279,13 @@
self.InferTensorRunAndCompare(model)
+ def testShapeInferenceReshape(self):
+ model = cnn.CNNModelHelper()
+ model.Reshape("X", ["Reshaped", "Old_Shape"], shape=[8, 0, -1, 2])
+ workspace.FeedBlob("X", np.random.rand(4, 26, 32).astype(np.float32))
+
+ self.InferTensorRunAndCompare(model)
+
def InferTensorRunAndCompare(self, model):
'''
Runs shape inference, and then the model to check