Add cost inference for fwd sparse operators and sparse adagrad (#9314)
Summary:
This adds cost inference for the forward-pass SparseLengths* operators and for SparseAdagrad. We should also add cost inference for sparse operators in the backward pass later.
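For context on how these estimates get consumed (e.g. by a profiler or a
cost-based scheduler), here is a minimal sketch of the lookup path, assuming
the OpSchemaRegistry::Schema(), OpSchema::HasCostInferenceFunction(), and
OpSchema::InferCost() APIs declared in caffe2/core/operator_schema.h;
EstimateCost itself is a hypothetical helper for illustration:

    #include "caffe2/core/operator_schema.h"

    namespace caffe2 {

    // Returns the schema-inferred cost for an operator, or a
    // value-initialized (zeroed) Cost when nothing is registered.
    OpSchema::Cost EstimateCost(
        const OperatorDef& def,
        const std::vector<TensorShape>& input_shapes) {
      const OpSchema* schema = OpSchemaRegistry::Schema(def.type());
      if (schema == nullptr || !schema->HasCostInferenceFunction()) {
        return OpSchema::Cost{};
      }
      return schema->InferCost(def, input_shapes);
    }

    } // namespace caffe2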
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9314
Reviewed By: orionr
Differential Revision: D8789240
Pulled By: jspark1105
fbshipit-source-id: 68c2170f294fe13bcc409276f599b5fa8a98bcd3
diff --git a/caffe2/core/operator_schema.h b/caffe2/core/operator_schema.h
index 6fafcf4..d902d04 100644
--- a/caffe2/core/operator_schema.h
+++ b/caffe2/core/operator_schema.h
@@ -271,7 +271,7 @@
OpSchema& Input(const int n, const char* name, const char* description);
OpSchema& Output(const int n, const char* name, const char* description);
// Calls the passed function with `this` as an argument. Useful for
- // adding docs for temlated/macro ops.
+ // adding docs for templated/macro ops.
OpSchema& FillUsing(std::function<void(OpSchema&)> populator);
// Remove from documentation
@@ -478,9 +478,9 @@
}
// Helper function
-inline uint64_t nElemFromDim(const TensorShape& X) {
- uint64_t nElem = X.dims_size() > 0 ? 1 : 0;
- for (int i = 0; i < X.dims_size(); ++i) {
+inline uint64_t nElemFromDim(const TensorShape& X, int dim = 0) {
+ uint64_t nElem = 1;
+ for (int i = dim; i < X.dims_size(); ++i) {
nElem *= X.dims(i);
}
return nElem;
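To illustrate the new dim argument (shape values assumed for the example):

    // For a TensorShape X with dims [8, 3, 4]:
    //   nElemFromDim(X)    == 8 * 3 * 4 == 96  // all elements, as before
    //   nElemFromDim(X, 1) == 3 * 4     == 12  // elements per slice of axis 0
    //   nElemFromDim(X, 3) == 1                // empty product
    //
    // Behavior change: a 0-dim shape now yields 1 (the empty product),
    // where the old version returned 0.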
diff --git a/caffe2/operators/segment_reduction_op.cc b/caffe2/operators/segment_reduction_op.cc
index 95d507b..05f265d 100644
--- a/caffe2/operators/segment_reduction_op.cc
+++ b/caffe2/operators/segment_reduction_op.cc
@@ -2,6 +2,47 @@
namespace caffe2 {
+OpSchema::Cost CostInferenceForSparseLengths(
+ const OperatorDef& def,
+ const vector<TensorShape>& inputs,
+ bool use_weight) {
+ int min_num_of_inputs = 3 + use_weight;
+ CAFFE_ENFORCE_GE(
+ inputs.size(),
+ min_num_of_inputs,
+ def.type() + " requires at least " +
+ caffe2::to_string(min_num_of_inputs) + " inputs");
+
+ const TensorShape data = inputs[0];
+ const TensorShape indices = inputs[1 + use_weight];
+ const TensorShape lengths = inputs[2 + use_weight];
+
+ OpSchema::Cost c;
+ CAFFE_ENFORCE_GT(data.dims_size(), 0, "data requires at least 1 dimension");
+ uint64_t N = data.dims(0);
+ if (N == 0) {
+ return c;
+ }
+ uint64_t D = nElemFromDim(data, 1);
+ CAFFE_ENFORCE_GT(
+ lengths.dims_size(), 0, "lengths requires at least 1 dimension");
+ uint64_t M = lengths.dims(0);
+ uint64_t indices_size = nElemFromDim(indices);
+
+ c.flops = indices_size * D;
+ c.bytes_read = indices_size *
+ (D * sizeof(data.data_type()) + sizeof(indices.data_type())) +
+ M * sizeof(lengths.data_type());
+ c.params_bytes = N * D * sizeof(data.data_type());
+ if (use_weight) {
+ const TensorShape weights = inputs[1];
+ c.flops += indices_size * D;
+ c.bytes_read += indices_size * sizeof(weights.data_type());
+ }
+
+ return c;
+}
+
// registering 5 input gradient with main output
// gradient of SparseLengthsWeightedSum
OPERATOR_SCHEMA(SparseLengthsIndicesInGradientWeightedSumWithMainInputGradient)
@@ -296,19 +337,15 @@
string doc = Def::doc;
ReplaceAll(doc, "{op}", Def::OpDef::name);
ReplaceAll(doc, "{op_doc}", Def::OpDef::doc);
- if (strcmp(Def::OpDef::name,"Max") == 0){
+ if (strcmp(Def::OpDef::name, "Max") == 0) {
ReplaceAll(doc, "{extra}", kLengthsMaxExtra);
- }
- else if (strcmp(Def::OpDef::name,"Mean") == 0){
+ } else if (strcmp(Def::OpDef::name, "Mean") == 0) {
ReplaceAll(doc, "{extra}", kLengthsMeanExtra);
- }
- else if (strcmp(Def::OpDef::name,"Sum") == 0){
+ } else if (strcmp(Def::OpDef::name, "Sum") == 0) {
ReplaceAll(doc, "{extra}", kLengthsSumExtra);
- }
- else if (strcmp(Def::OpDef::name,"WeightedSum") == 0){
+ } else if (strcmp(Def::OpDef::name, "WeightedSum") == 0) {
ReplaceAll(doc, "{extra}", kLengthsWeightedSumExtra);
- }
- else{
+ } else {
ReplaceAll(doc, "{extra}", " ");
}
return doc;
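To make the formulas in CostInferenceForSparseLengths concrete, a worked
example with assumed shapes and 4-byte element types throughout:

    // data:    [1000000, 64]  -> N = 1e6, D = 64
    // indices: [10000]        -> indices_size = 1e4
    // lengths: [100]          -> M = 100
    //
    // flops        = 10000 * 64                     =     640,000
    // bytes_read   = 10000 * (64 * 4 + 4) + 100 * 4 =   2,600,400
    // params_bytes = 1000000 * 64 * 4               = 256,000,000
    //
    // With use_weight == true (SparseLengthsWeightedSum), flops doubles to
    // 1,280,000 and bytes_read grows by 10000 * 4 for the weights.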
diff --git a/caffe2/operators/segment_reduction_op.h b/caffe2/operators/segment_reduction_op.h
index 50b3446..5ef2c3d 100644
--- a/caffe2/operators/segment_reduction_op.h
+++ b/caffe2/operators/segment_reduction_op.h
@@ -1964,6 +1964,11 @@
GradientNeedIndices>;
};
+OpSchema::Cost CostInferenceForSparseLengths(
+ const OperatorDef& def,
+ const vector<TensorShape>& inputs,
+ bool use_weight);
+
template <
typename T,
typename SIndex,
@@ -2015,6 +2020,13 @@
return out;
});
ReducerDef::PopulateSchema(schema);
+
+ schema.CostInferenceFunction(
+ [](const OperatorDef& def,
+ const vector<TensorShape>& inputs) -> OpSchema::Cost {
+ return CostInferenceForSparseLengths(
+ def, inputs, strcmp(OpDef::name, "WeightedSum") == 0);
+ });
}
using Reducer = typename ReducerDef::template Reducer<T, Context>;
using ReducerGradient =
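Once this registration runs, the cost function is reachable through the
schema registry. A hedged sanity-check sketch (MakeShape and
CheckSparseLengthsSumCost are hypothetical names; the shapes mirror the
worked example above):

    #include "caffe2/core/operator_schema.h"

    namespace caffe2 {

    static TensorShape MakeShape(const std::vector<int64_t>& dims) {
      TensorShape shape;
      for (int64_t d : dims) {
        shape.add_dims(d);
      }
      shape.set_data_type(TensorProto::FLOAT);
      return shape;
    }

    void CheckSparseLengthsSumCost() {
      OperatorDef def;
      def.set_type("SparseLengthsSum");
      const std::vector<TensorShape> inputs = {
          MakeShape({1000000, 64}), // data
          MakeShape({10000}), // indices
          MakeShape({100})}; // lengths
      const OpSchema* schema = OpSchemaRegistry::Schema(def.type());
      CAFFE_ENFORCE(schema && schema->HasCostInferenceFunction());
      const OpSchema::Cost c = schema->InferCost(def, inputs);
      CAFFE_ENFORCE_EQ(c.flops, 640000ULL); // indices_size * D
    }

    } // namespace caffe2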
diff --git a/caffe2/sgd/adagrad_op.cc b/caffe2/sgd/adagrad_op.cc
index c963f54..0df4f0c 100644
--- a/caffe2/sgd/adagrad_op.cc
+++ b/caffe2/sgd/adagrad_op.cc
@@ -37,6 +37,33 @@
"Default 1. If it is in (0, 1), the gradient square sum "
"is decayed by this factor.");
+static OpSchema::Cost CostInferenceForSparseAdagrad(
+ const OperatorDef& /* unused */,
+ const vector<TensorShape>& inputs) {
+ CAFFE_ENFORCE_GE(
+ inputs.size(), 4, "SparseAdagrad requires at least 4 inputs");
+
+ const TensorShape param = inputs[0];
+ const TensorShape moment = inputs[1];
+ const TensorShape indices = inputs[2];
+ const TensorShape grad = inputs[3];
+
+ uint64_t n = nElemFromDim(indices);
+ uint64_t grad_size = nElemFromDim(grad);
+
+ OpSchema::Cost c;
+ // See adagrad_op.h (note that decay is 1 for SparseAdagrad).
+ // 2 multiplications, 3 additions, 1 division, and 1 sqrt
+ // (optimistically count sqrt as one flop).
+ c.flops = grad_size * 7;
+ c.bytes_written =
+ grad_size * (sizeof(param.data_type()) + sizeof(moment.data_type()));
+ c.bytes_read = c.bytes_written + grad_size * sizeof(grad.data_type()) +
+ n * sizeof(indices.data_type());
+
+ return c;
+}
+
REGISTER_CPU_OPERATOR(SparseAdagrad, SparseAdagradOp<float, CPUContext>);
OPERATOR_SCHEMA(SparseAdagrad)
.NumInputs(5)
@@ -56,7 +83,9 @@
.Input(4, "lr", "learning rate")
.Output(0, "output_param", "Updated parameters")
.Output(1, "output_moment_1", "Updated moment")
- .Arg("epsilon", "Default 1e-5");
+ .Arg("epsilon", "Default 1e-5")
+ .CostInferenceFunction(
+ OpSchema::CostInferenceFunctionType(CostInferenceForSparseAdagrad));
REGISTER_CPU_OPERATOR(
RowWiseSparseAdagrad,
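And a matching worked example for CostInferenceForSparseAdagrad, again with
assumed shapes and 4-byte element types:

    // param/moment: [1000000, 16]
    // indices:      [1000]      -> n = 1000
    // grad:         [1000, 16]  -> grad_size = 16,000
    //
    // flops         = 16000 * 7                  = 112,000
    // bytes_written = 16000 * (4 + 4)            = 128,000
    // bytes_read    = 128000 + 16000*4 + 1000*4  = 196,000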