[shape inference] add shape inference support for pruning ops

Summary:
* To make the pruning ops compatible with shape inference, we introduced a new quantile argument (as in D23463390) that differentiates dynamic from fixed pruning.

* The fixed pruning op has well-defined output shapes, but its input shapes are not determined. This diff therefore bypasses the input shape check for the two pruning ops (RowwisePruneI64 and RowwisePruneI32).
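
A minimal, standalone sketch of the input-check bypass, for illustration only. `Shape`, `ShapeMap`, and `GatherInputShapes` are hypothetical stand-ins and not part of caffe2; the real logic lives in `BoundShapeInferencer::InferCommonOp` and also handles `types_with_independent_output_shape`, which is left out here. The sketch assumes C++17 for `std::optional`.

```cpp
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for caffe2's TensorShape. An empty dims vector plays
// the role of a default-constructed (unknown) shape.
struct Shape {
  std::vector<long> dims;
};

using ShapeMap = std::unordered_map<std::string, Shape>;

// Collects one shape per input name. Returns std::nullopt when an input has
// no recorded shape and the check is not bypassed, which corresponds to the
// "skip this op" early return. With the bypass enabled, a missing input
// contributes an empty placeholder shape instead.
std::optional<std::vector<Shape>> GatherInputShapes(
    const std::vector<std::string>& inputs,
    const ShapeMap& known_shapes,
    bool bypass_input_check) {
  std::vector<Shape> input_shapes;
  input_shapes.reserve(inputs.size());
  for (const auto& name : inputs) {
    const auto it = known_shapes.find(name);
    if (it == known_shapes.end()) {
      if (!bypass_input_check) {
        return std::nullopt;  // shape info missing: give up on this op
      }
      input_shapes.emplace_back();  // placeholder for the unknown input
    } else {
      input_shapes.push_back(it->second);
    }
  }
  return input_shapes;
}
```

With inputs `{"weights", "mask"}` (hypothetical names) and only `weights` present in the map, the call yields `std::nullopt` when `bypass_input_check` is false, and two shapes (the known one plus an empty placeholder) when it is true.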

Test Plan:
buck test caffe2/caffe2/opt:bound_shape_inference_test

```
Started reporting to test run: https://our.intern.facebook.com/intern/testinfra/testrun/844425102187909
    ✓ ListingSuccess: caffe2/caffe2/opt:bound_shape_inference_test - main (1.973)
    ✓ Pass: caffe2/caffe2/opt:bound_shape_inference_test - BoundShapeInference.FC3D (2.604)
    ✓ Pass: caffe2/caffe2/opt:bound_shape_inference_test - BoundShapeInference.SparseLengthsSumFused4BitRowwise (2.635)
    ✓ Pass: caffe2/caffe2/opt:bound_shape_inference_test - BoundShapeInference.FC (2.690)
    ✓ Pass: caffe2/caffe2/opt:bound_shape_inference_test - BoundShapeInference.Int8QuantizeInferInputBackwards (2.705)
    ✓ Pass: caffe2/caffe2/opt:bound_shape_inference_test - BoundShapeInference.SparseLengthsSum (2.729)
    ✓ Pass: caffe2/caffe2/opt:bound_shape_inference_test - BoundShapeInference.Reshape (2.754)
    ✓ Pass: caffe2/caffe2/opt:bound_shape_inference_test - BoundShapeInference.ConcatMissingInput (2.770)
    ✓ Pass: caffe2/caffe2/opt:bound_shape_inference_test - BoundShapeInference.ElementwiseOp (2.770)
    ✓ Pass: caffe2/caffe2/opt:bound_shape_inference_test - BoundShapeInference.Tile (2.785)
    ✓ Pass: caffe2/caffe2/opt:bound_shape_inference_test - BoundShapeInference.Bucketize (2.789)
    ✓ Pass: caffe2/caffe2/opt:bound_shape_inference_test - BoundShapeInference.SparseLengthsSumFused8BitRowwise (2.807)
    ✓ Pass: caffe2/caffe2/opt:bound_shape_inference_test - BoundShapeInference.SparseLengthsSum8BitRowwiseSparse (2.841)
    ✓ Pass: caffe2/caffe2/opt:bound_shape_inference_test - BoundShapeInference.Split (2.863)
    ✓ Pass: caffe2/caffe2/opt:bound_shape_inference_test - BoundShapeInference.ConcatInferInputBackwards (2.894)
    ✓ Pass: caffe2/caffe2/opt:bound_shape_inference_test - BoundShapeInference.ElementwiseInferInputBackwards (2.898)
    ✓ Pass: caffe2/caffe2/opt:bound_shape_inference_test - BoundShapeInference.Combo0 (2.902)
    ✓ Pass: caffe2/caffe2/opt:bound_shape_inference_test - BoundShapeInference.LengthsRangeFill (2.964)
    ✓ Pass: caffe2/caffe2/opt:bound_shape_inference_test - BoundShapeInference.Quantization (2.964)
Summary
  Pass: 18
  ListingSuccess: 1
Finished test run: https://our.intern.facebook.com/intern/testinfra/testrun/844425102187909
```

buck test caffe2/caffe2/fb/opt:bound_shape_inference_net_test

```
 Started reporting to test run: https://our.intern.facebook.com/intern/testinfra/testrun/3096224780078093
    ✓ ListingSuccess: caffe2/caffe2/fb/opt:bound_shape_inference_net_test - main (14.092)
    ✓ Pass: caffe2/caffe2/fb/opt:bound_shape_inference_net_test - BoundShapeInference.ClipLengths (15.508)
    ✓ Pass: caffe2/caffe2/fb/opt:bound_shape_inference_net_test - BoundShapeInference.DPER3IdListFeaturePreProcessing (15.521)
    ✓ Pass: caffe2/caffe2/fb/opt:bound_shape_inference_net_test - BoundShapeInference.ClipRanges (16.198)
    ✓ Pass: caffe2/caffe2/fb/opt:bound_shape_inference_net_test - BoundShapeInference.RowwisePrune (16.302)
    ✓ Pass: caffe2/caffe2/fb/opt:bound_shape_inference_net_test - FbBoundShapeInferencerTest.GatherRanges1 (16.585)
    ✓ Pass: caffe2/caffe2/fb/opt:bound_shape_inference_net_test - BoundShapeInference.Combo3 (16.865)
    ✓ Pass: caffe2/caffe2/fb/opt:bound_shape_inference_net_test - BoundShapeInference.DPER3IdListFeaturePreProcessingWithCast (16.907)
    ✓ Pass: caffe2/caffe2/fb/opt:bound_shape_inference_net_test - BoundShapeInference.GatherRanges2 (16.921)
    ✓ Pass: caffe2/caffe2/fb/opt:bound_shape_inference_net_test - FbBoundShapeInferencerTest.LengthsRangeFill (17.157)
    ✓ Pass: caffe2/caffe2/fb/opt:bound_shape_inference_net_test - BoundShapeInference.ClipRangesAndGatherRanges (17.277)
    ✓ Pass: caffe2/caffe2/fb/opt:bound_shape_inference_net_test - BoundShapeInference.DPER3IdScoreListFeaturePreProcessing (17.274)
    ✓ Pass: caffe2/caffe2/fb/opt:bound_shape_inference_net_test - BoundShapeInference.ClipRangesGatherSigridHash (17.554)
    ✓ Pass: caffe2/caffe2/fb/opt:bound_shape_inference_net_test - BoundShapeInference.Combo1 (17.645)
    ✓ Pass: caffe2/caffe2/fb/opt:bound_shape_inference_net_test - BoundShapeInference.DPER3IdScoreListFeaturePreProcessingDEFAULT (17.887)
    ✓ Pass: caffe2/caffe2/fb/opt:bound_shape_inference_net_test - BoundShapeInference.DPER3IdListFeaturePreProcessingDEFAULT (17.929)
    ✓ Pass: caffe2/caffe2/fb/opt:bound_shape_inference_net_test - BoundShapeInference.f97293388_0 (19.343)
    ✓ Pass: caffe2/caffe2/fb/opt:bound_shape_inference_net_test - FbBoundShapeInferencerTest.GatherRangesToDense1 (19.489)
    ✓ Pass: caffe2/caffe2/fb/opt:bound_shape_inference_net_test - BoundShapeInference.DPER3IdScoreListFeaturePreProcessingWithCast (19.887)
    ✓ Pass: caffe2/caffe2/fb/opt:bound_shape_inference_net_test - BoundShapeInference.xray_v11 (19.905)
    ✓ Pass: caffe2/caffe2/fb/opt:bound_shape_inference_net_test - FbBoundShapeInferencerTest.SigridTransforms (20.080)
    ✓ Pass: caffe2/caffe2/fb/opt:bound_shape_inference_net_test - BoundShapeInference.Combo2 (20.086)
    ✓ Pass: caffe2/caffe2/fb/opt:bound_shape_inference_net_test - BoundShapeInference.vanillaSparseNN (59.847)
    ✓ Pass: caffe2/caffe2/fb/opt:bound_shape_inference_net_test - BoundShapeInference.gather (97.822)
Summary
  Pass: 23
  ListingSuccess: 1
```

## Workflow testing

* non-DI/fixed quantile/user side/non-self-binning
f224250571

* non-DI/fixed quantile/user+ad side/non-self-binning
f224250610

* DI/fixed quantile/user side/self-binning
f224250637

* DI/fixed quantile/user+ad side/self-binning
f224250662

* non-DI/dynamic quantile/user+ad side/non-self-binning
f224250705

* DI/dynamic quantile/user+ad side/self-binning
f224250760

Reviewed By: ChunliF

Differential Revision: D23647390

fbshipit-source-id: 3ec1c0eaea53bd4d5eda4a0436577216f7fa8ead
diff --git a/caffe2/opt/bound_shape_inferencer.cc b/caffe2/opt/bound_shape_inferencer.cc
index d8fe956..c513c1a 100644
--- a/caffe2/opt/bound_shape_inferencer.cc
+++ b/caffe2/opt/bound_shape_inferencer.cc
@@ -234,7 +234,8 @@
     bool is_quantized,
     bool allow_existing_shape,
     float scale,
-    int offset) {
+    int offset,
+    bool in_place_op) {
   auto rt = shape_info_.emplace(name, ShapeInfo());
   ShapeInfo& shape_info = rt.first->second;
   TensorShape& shape = shape_info.shape;
@@ -246,8 +247,8 @@
     shape_info.q_info.offset.push_back(offset);
     shape_info.q_info.axis = 1;
   }
-  // If the shape information exists in shape_info_ already
-  if (!rt.second) {
+  // If the shape information exists in shape_info_ already and we want to compare old/new shapes
+  if (!rt.second && !in_place_op) {
     // Check dim size consistency
     CAFFE_ENFORCE_EQ(
         shape.dims_size(),
@@ -290,13 +291,19 @@
     return shape;
   }
   // If shape information does not exist in shape_info_,
+  // or shape info is not final,
   // set shape info according to inputs.
-  shape_info.setDimType(t);
-  shape.mutable_dims()->Clear();
-  for (const auto d : bound_dims) {
-    shape.add_dims(d);
+  if (!shape_info.getShapeIsFinal()) {
+    shape_info.setDimType(t);
+    shape.mutable_dims()->Clear();
+    for (const auto d : bound_dims) {
+      shape.add_dims(d);
+    }
+    shape.set_data_type(type);
+    if (in_place_op) {
+      shape_info.setShapeIsFinal(true);
+    }
   }
-  shape.set_data_type(type);
   return shape;
 }
 
@@ -851,7 +858,12 @@
       false);
 }
 
-void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
+void BoundShapeInferencer::InferCommonOp(
+  const OperatorDef& op,
+  const OpSchema* schema,
+  bool bypass_input_check,
+  bool in_place_op
+) {
   // First, we need to check that all the input shape/types are already
   // presented
   try {
@@ -859,25 +871,30 @@
         types_with_independent_output_shape = {"Int8GenQuantParams",
                                                "Int8QuantSchemeBlobFill",
                                                "ComputeEqualizationScale"};
+    const static std::unordered_set<std::string>
+        pruning_ops = {"RowwisePruneI64", "RowwisePruneI32"};
     std::vector<TensorShape> input_shapes;
     for (const auto& input : op.input()) {
       const auto it = shape_info_.find(input);
       if (it == shape_info_.end() &&
-          !types_with_independent_output_shape.count(op.type())) {
+          !types_with_independent_output_shape.count(op.type()) && !bypass_input_check) {
         LOG(WARNING) << "Cannot find shape info for " << input << ". Skipping "
                      << op.type();
         return;
       }
-      if (types_with_independent_output_shape.count(op.type())) {
+      if (types_with_independent_output_shape.count(op.type()) || (bypass_input_check && it == shape_info_.end())) {
         TensorShape input_shape;
         input_shapes.emplace_back(std::move(input_shape));
-
       } else {
         input_shapes.emplace_back(it->second.shape);
       }
     }
 
-    const OpSchema* schema = OpSchemaRegistry::Schema(op.type());
+    // Schema can be pre-defined.
+    // If not predefined, get the schema for the op.
+    if (schema == nullptr) {
+      schema = OpSchemaRegistry::Schema(op.type());
+    }
     CAFFE_ENFORCE(schema);
     std::vector<TensorShape> output_shapes;
     output_shapes = schema->InferTensor(op, input_shapes);
@@ -923,7 +940,7 @@
 
     for (int i = 0; i < output_shapes.size(); i++) {
       const auto& shape = output_shapes[i];
-      if (infered_data_type == TensorProto::UNDEFINED) {
+      if (infered_data_type == TensorProto::UNDEFINED || pruning_ops.find(op.type()) != pruning_ops.end()) {
         infered_data_type = shape.data_type();
       }
       if (shape.unknown_shape()) {
@@ -937,7 +954,8 @@
           is_quantized,
           false,
           scale,
-          offset);
+          offset,
+          in_place_op);
     }
   } catch (const caffe2::EnforceNotMet& e) {
     LOG(ERROR) << "Enforce not met while inferring shapes for " << op.type()
diff --git a/caffe2/opt/bound_shape_inferencer.h b/caffe2/opt/bound_shape_inferencer.h
index cf03478..662121b 100644
--- a/caffe2/opt/bound_shape_inferencer.h
+++ b/caffe2/opt/bound_shape_inferencer.h
@@ -107,7 +107,8 @@
       bool is_quantized,
       bool allow_existing_shape = false,
       float scale = 1,
-      int offset = 0);
+      int offset = 0,
+      bool in_place_op = false);
 
   TensorShape& SetTensorBoundShapeIfNotExist(
       const std::string& name,
@@ -136,7 +137,7 @@
 
   // Standard shape/type inference using op schema registered shape inference
   // function
-  void InferCommonOp(const OperatorDef& op);
+  void InferCommonOp(const OperatorDef& op, const OpSchema* schema = nullptr, bool bypass_input_check = false, bool in_place_op = false);
 
   // Initialize private parameters, such as shape_info, extract_feature_len_
   // This is called at the beginning of InferBoundShapeAndType()