| #include "caffe2/operators/fc_inference.h" |
| |
| namespace caffe2 { |
| std::vector<TensorShape> FCShapeInference( |
| const OperatorDef& def, |
| const vector<TensorShape>& in, |
| bool pretransposed_weight) { |
| vector<TensorShape> out(1); |
| ArgumentHelper helper(def); |
| |
| auto axis = helper.GetSingleArgument<int32_t>("axis", 1); |
| const auto canonical_axis = canonical_axis_index_(axis, in[0].dims().size()); |
| auto axis_w = helper.GetSingleArgument<int32_t>("axis_w", 1); |
| const int canonical_axis_w = |
| canonical_axis_index_(axis_w, in[1].dims().size()); |
| const int N = pretransposed_weight |
| ? size_from_dim_(canonical_axis_w, GetDimsVector(in[1])) |
| : size_to_dim_(canonical_axis_w, GetDimsVector(in[1])); |
| |
| vector<int> y_shape(in[0].dims().begin(), in[0].dims().end()); |
| CAFFE_ENFORCE_LE(canonical_axis + 1, y_shape.size()); |
| y_shape.resize(canonical_axis + 1); |
| y_shape[canonical_axis] = N; |
| out[0] = CreateTensorShape(y_shape, in[0].data_type()); |
| return out; |
| } |
| |
| OpSchema::Cost CostInferenceForFC( |
| const OperatorDef& def, |
| const vector<TensorShape>& in) { |
| CAFFE_ENFORCE_EQ(in.size(), 3, "FC requires three inputs"); |
| struct OpSchema::Cost c; |
| ArgumentHelper helper(def); |
| |
| const auto& X = in[0]; |
| const auto& W = in[1]; |
| const auto& b = in[2]; |
| auto axis = helper.GetSingleArgument<int32_t>("axis", 1); |
| const auto canonical_axis = canonical_axis_index_(axis, in[0].dims().size()); |
| const int M = size_to_dim_(canonical_axis, GetDimsVector(in[0])); |
| const int K = size_from_dim_(canonical_axis, GetDimsVector(in[0])); |
| auto axis_w = helper.GetSingleArgument<int32_t>("axis_w", 1); |
| const int canonical_axis_w = |
| canonical_axis_index_(axis_w, in[1].dims().size()); |
| const int N = size_to_dim_(canonical_axis_w, GetDimsVector(in[1])); |
| |
| uint64_t nElemX = nElemFromDim(X); |
| uint64_t nElemW = nElemFromDim(W); |
| uint64_t nElemB = nElemFromDim(b); |
| c.flops = 2 * K * M * N + M * N; |
| c.bytes_read = (nElemX + nElemW + nElemB) * sizeof(X.data_type()); |
| c.bytes_written = M * N * sizeof(X.data_type()); |
| c.params_bytes = (K * N + N) * sizeof(X.data_type()); |
| return c; |
| } |
| } // namespace caffe2 |