Add Shape Inference for Reshape Operator

Summary: Add shape inference for reshape. Because it cannot do shape inference for reshaped tensor with runtime tensor data, set `out[0].set_unknown_shape(true)` if no `shape` argument is used.

Differential Revision: D4671125

fbshipit-source-id: 685a9198f9b08e3336014c792f20051b381d8619
diff --git a/caffe2/operators/reshape_op.cc b/caffe2/operators/reshape_op.cc
index 974e73e..0837495 100644
--- a/caffe2/operators/reshape_op.cc
+++ b/caffe2/operators/reshape_op.cc
@@ -8,6 +8,101 @@
 OPERATOR_SCHEMA(Reshape)
     .NumInputs(1, 2)
     .NumOutputs(2)
+    .TensorInferenceFunction(
+        [](const OperatorDef& def, const vector<TensorShape>& in) {
+          vector<TensorShape> out(2);
+
+          // Do shape inference for old_shape
+          out[1].set_data_type(TensorProto::INT64);
+          out[1].add_dims(in[0].dims_size());
+
+          ArgumentHelper helper(def);
+          if (!helper.HasArgument("shape")) {
+            // Cannot do shape inference for reshaped tensor from runtime data.
+            CAFFE_ENFORCE_EQ(
+                in.size(),
+                2,
+                "New shape must be specified by either the input blob or the "
+                "argument `shape`.");
+            out[0].set_unknown_shape(true);
+            return out;
+          }
+          CAFFE_ENFORCE_EQ(
+              in.size(),
+              1,
+              "New shape must not be specified by the input blob and the "
+              "argument `shape` at the same time.");
+
+          // Infer the actual new shape
+          auto actualNewShape = helper.GetRepeatedArgument<int64_t>("shape");
+
+          // Copy over the dimensions for those that are specified zero
+          // and check the eligibility of input
+          for (int i = 0; i < actualNewShape.size(); ++i) {
+            CAFFE_ENFORCE_GE(
+                actualNewShape[i],
+                -1,
+                "The dimensions in argument `shape` "
+                "must not be a negative number.");
+
+            if (actualNewShape[i] == 0) {
+              CAFFE_ENFORCE_LT(
+                  i,
+                  in[0].dims_size(),
+                  "Argument `shape` has a dimension set to zero that exceeds "
+                  "the original dimension size.");
+              actualNewShape[i] = in[0].dims(i);
+            }
+          }
+
+          // Check if the new shape is valid and fills in the missing dimension
+          // specified by -1.
+          int64_t totalSize = 1;
+          for (const auto d : in[0].dims()) {
+            totalSize *= d;
+          }
+          int64_t size = 1;
+          int unknownIdx = -1;
+          for (int i = 0; i < actualNewShape.size(); ++i) {
+            const auto dim = actualNewShape[i];
+            if (dim == -1) {
+              CAFFE_ENFORCE(
+                  unknownIdx == -1,
+                  "Argument `shape` has more than one missing dimension.");
+              unknownIdx = i;
+            } else {
+              size *= dim;
+            }
+          }
+
+          if (unknownIdx != -1) {
+            CAFFE_ENFORCE(
+                totalSize % size == 0,
+                "Argument `shape` does not agree with the input data.",
+                " (",
+                totalSize,
+                " vs ",
+                size,
+                ")");
+            actualNewShape[unknownIdx] = totalSize / size;
+          } else {
+            CAFFE_ENFORCE_EQ(
+                totalSize,
+                size,
+                "Argument `shape` does not agree with the input data.",
+                " (",
+                totalSize,
+                " != ",
+                size,
+                ")");
+          }
+
+          out[0].set_data_type(in[0].data_type());
+          for (const auto d : actualNewShape) {
+            out[0].add_dims(d);
+          }
+          return out;
+        })
     .AllowInplace({{0, 0}})
     .SetDoc(R"DOC(
 Reshape the input tensor similar to numpy.reshape.
diff --git a/caffe2/operators/utility_ops.cc b/caffe2/operators/utility_ops.cc
index 1a080de..9080bcc 100644
--- a/caffe2/operators/utility_ops.cc
+++ b/caffe2/operators/utility_ops.cc
@@ -113,6 +113,7 @@
             }
             total *= d;
           }
+          out[0].set_data_type(in[0].data_type());
           out[0].add_dims(in[0].dims(0));
           out[0].add_dims(total);
           return out;
@@ -139,6 +140,7 @@
           for (auto d : in[0].dims()) {
             total *= d;
           }
+          out[0].set_data_type(in[0].data_type());
           out[0].add_dims(total);
           return out;
         })
@@ -182,7 +184,7 @@
           return out;
         })
     .SetDoc(R"DOC(
-Produces tensor condaining data of first input and shape of second input.
+Produces tensor containing data of first input and shape of second input.
 )DOC")
     .Input(0, "data", "Tensor whose data will be copied into the output.")
     .Input(1, "shape_tensor", "Tensor whose shape will be applied to output.")
diff --git a/caffe2/python/operator_test/shape_inference_test.py b/caffe2/python/operator_test/shape_inference_test.py
index 2fe91d1..23869d5 100644
--- a/caffe2/python/operator_test/shape_inference_test.py
+++ b/caffe2/python/operator_test/shape_inference_test.py
@@ -279,6 +279,13 @@
 
         self.InferTensorRunAndCompare(model)
 
+    def testShapeInferenceReshape(self):
+        model = cnn.CNNModelHelper()
+        model.Reshape("X", ["Reshaped", "Old_Shape"], shape=[8, 0, -1, 2])
+        workspace.FeedBlob("X", np.random.rand(4, 26, 32).astype(np.float32))
+
+        self.InferTensorRunAndCompare(model)
+
     def InferTensorRunAndCompare(self, model):
         '''
         Runs shape inference, and then the model to check