Add cost inference of fwd sparse operators and sparse adagrad (#9314)

Summary:
We should also add cost inference for sparse operators in backward pass later.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9314

Reviewed By: orionr

Differential Revision: D8789240

Pulled By: jspark1105

fbshipit-source-id: 68c2170f294fe13bcc409276f599b5fa8a98bcd3
diff --git a/caffe2/core/operator_schema.h b/caffe2/core/operator_schema.h
index 6fafcf4..d902d04 100644
--- a/caffe2/core/operator_schema.h
+++ b/caffe2/core/operator_schema.h
@@ -271,7 +271,7 @@
   OpSchema& Input(const int n, const char* name, const char* description);
   OpSchema& Output(const int n, const char* name, const char* description);
   // Calls the passed function with `this` as an argument. Useful for
-  // adding docs for temlated/macro ops.
+  // adding docs for templated/macro ops.
   OpSchema& FillUsing(std::function<void(OpSchema&)> populator);
 
   // Remove from documentation
@@ -478,9 +478,9 @@
 }
 
 // Helper function
-inline uint64_t nElemFromDim(const TensorShape& X) {
-  uint64_t nElem = X.dims_size() > 0 ? 1 : 0;
-  for (int i = 0; i < X.dims_size(); ++i) {
+inline uint64_t nElemFromDim(const TensorShape& X, int dim = 0) {
+  uint64_t nElem = 1;
+  for (int i = dim; i < X.dims_size(); ++i) {
     nElem *= X.dims(i);
   }
   return nElem;
diff --git a/caffe2/operators/segment_reduction_op.cc b/caffe2/operators/segment_reduction_op.cc
index 95d507b..05f265d 100644
--- a/caffe2/operators/segment_reduction_op.cc
+++ b/caffe2/operators/segment_reduction_op.cc
@@ -2,6 +2,47 @@
 
 namespace caffe2 {
 
+OpSchema::Cost CostInferenceForSparseLengths(
+    const OperatorDef& def,
+    const vector<TensorShape>& inputs,
+    bool use_weight) {
+  int min_num_of_inputs = 3 + use_weight;
+  CAFFE_ENFORCE_GE(
+      inputs.size(),
+      min_num_of_inputs,
+      def.type() + " requires at least " +
+          caffe2::to_string(min_num_of_inputs));
+
+  const TensorShape data = inputs[0];
+  const TensorShape indices = inputs[1 + use_weight];
+  const TensorShape lengths = inputs[2 + use_weight];
+
+  OpSchema::Cost c;
+  CAFFE_ENFORCE_GT(data.dims_size(), 0, "data requires at least 1 dimension");
+  uint64_t N = data.dims(0);
+  if (N == 0) {
+    return c;
+  }
+  uint64_t D = nElemFromDim(data, 1);
+  CAFFE_ENFORCE_GT(
+      lengths.dims_size(), 0, "lengths requires at least 1 dimension");
+  uint64_t M = lengths.dims(0);
+  uint64_t indices_size = nElemFromDim(indices);
+
+  c.flops = indices_size * D;
+  c.bytes_read = indices_size *
+          (D * sizeof(data.data_type()) + sizeof(indices.data_type())) +
+      M * sizeof(lengths.data_type());
+  c.params_bytes = N * D * sizeof(data.data_type());
+  if (use_weight) {
+    const TensorShape weights = inputs[1];
+    c.flops += indices_size * D;
+    c.bytes_read += indices_size * sizeof(weights.data_type());
+  }
+
+  return c;
+}
+
 // registering 5 input gradient with main output
 // gradient of SparseLengthsWeightedSum
 OPERATOR_SCHEMA(SparseLengthsIndicesInGradientWeightedSumWithMainInputGradient)
@@ -296,19 +337,15 @@
   string doc = Def::doc;
   ReplaceAll(doc, "{op}", Def::OpDef::name);
   ReplaceAll(doc, "{op_doc}", Def::OpDef::doc);
-  if (strcmp(Def::OpDef::name,"Max") == 0){
+  if (strcmp(Def::OpDef::name, "Max") == 0) {
     ReplaceAll(doc, "{extra}", kLengthsMaxExtra);
-  }
-  else if (strcmp(Def::OpDef::name,"Mean") == 0){
+  } else if (strcmp(Def::OpDef::name, "Mean") == 0) {
     ReplaceAll(doc, "{extra}", kLengthsMeanExtra);
-  }
-  else if (strcmp(Def::OpDef::name,"Sum") == 0){
+  } else if (strcmp(Def::OpDef::name, "Sum") == 0) {
     ReplaceAll(doc, "{extra}", kLengthsSumExtra);
-  }
-  else if (strcmp(Def::OpDef::name,"WeightedSum") == 0){
+  } else if (strcmp(Def::OpDef::name, "WeightedSum") == 0) {
     ReplaceAll(doc, "{extra}", kLengthsWeightedSumExtra);
-  }
-  else{
+  } else {
     ReplaceAll(doc, "{extra}", " ");
   }
   return doc;
diff --git a/caffe2/operators/segment_reduction_op.h b/caffe2/operators/segment_reduction_op.h
index 50b3446..5ef2c3d 100644
--- a/caffe2/operators/segment_reduction_op.h
+++ b/caffe2/operators/segment_reduction_op.h
@@ -1964,6 +1964,11 @@
       GradientNeedIndices>;
 };
 
+OpSchema::Cost CostInferenceForSparseLengths(
+    const OperatorDef& def,
+    const vector<TensorShape>& inputs,
+    bool use_weight);
+
 template <
     typename T,
     typename SIndex,
@@ -2015,6 +2020,13 @@
           return out;
         });
     ReducerDef::PopulateSchema(schema);
+
+    schema.CostInferenceFunction(
+        [](const OperatorDef& def,
+           const vector<TensorShape>& inputs) -> OpSchema::Cost {
+          return CostInferenceForSparseLengths(
+              def, inputs, strcmp(OpDef::name, "WeightedSum") == 0);
+        });
   }
   using Reducer = typename ReducerDef::template Reducer<T, Context>;
   using ReducerGradient =
diff --git a/caffe2/sgd/adagrad_op.cc b/caffe2/sgd/adagrad_op.cc
index c963f54..0df4f0c 100644
--- a/caffe2/sgd/adagrad_op.cc
+++ b/caffe2/sgd/adagrad_op.cc
@@ -37,6 +37,33 @@
         "Default 1. If it is in (0, 1), the gradient square sum "
         "is decayed by this factor.");
 
+static OpSchema::Cost CostInferenceForSparseAdagrad(
+    const OperatorDef& /* unused */,
+    const vector<TensorShape>& inputs) {
+  CAFFE_ENFORCE_GE(
+      inputs.size(), 4, "SparseAdagrad requires at least 4 inputs");
+
+  const TensorShape param = inputs[0];
+  const TensorShape moment = inputs[1];
+  const TensorShape indices = inputs[2];
+  const TensorShape grad = inputs[3];
+
+  uint64_t n = nElemFromDim(indices);
+  uint64_t grad_size = nElemFromDim(grad);
+
+  OpSchema::Cost c;
+  // See adagrad_op.h (note that decay is 1 for SparseAdagrad).
+  // 2 multiplications, 3 additions, 1 division, and 1 sqrt
+  // (optimistically count sqrt as one flop).
+  c.flops = grad_size * 7;
+  c.bytes_written =
+      grad_size * (sizeof(param.data_type()) + sizeof(moment.data_type()));
+  c.bytes_read = c.bytes_written + grad_size * sizeof(grad.data_type()) +
+      n * sizeof(indices.data_type());
+
+  return c;
+}
+
 REGISTER_CPU_OPERATOR(SparseAdagrad, SparseAdagradOp<float, CPUContext>);
 OPERATOR_SCHEMA(SparseAdagrad)
     .NumInputs(5)
@@ -56,7 +83,9 @@
     .Input(4, "lr", "learning rate")
     .Output(0, "output_param", "Updated parameters")
     .Output(1, "output_moment_1", "Updated moment")
-    .Arg("epsilon", "Default 1e-5");
+    .Arg("epsilon", "Default 1e-5")
+    .CostInferenceFunction(
+        OpSchema::CostInferenceFunctionType(CostInferenceForSparseAdagrad));
 
 REGISTER_CPU_OPERATOR(
     RowWiseSparseAdagrad,