blob: 2c11e8a72bfbae942829228ac84b8dece9bded57 [file] [log] [blame]
#include "caffe2/operators/fc_inference.h"
namespace caffe2 {
std::vector<TensorShape> FCShapeInference(
const OperatorDef& def,
const vector<TensorShape>& in,
bool pretransposed_weight) {
vector<TensorShape> out(1);
ArgumentHelper helper(def);
auto axis = helper.GetSingleArgument<int32_t>("axis", 1);
const auto canonical_axis = canonical_axis_index_(axis, in[0].dims().size());
auto axis_w = helper.GetSingleArgument<int32_t>("axis_w", 1);
const int canonical_axis_w =
canonical_axis_index_(axis_w, in[1].dims().size());
const int N = pretransposed_weight
? size_from_dim_(canonical_axis_w, GetDimsVector(in[1]))
: size_to_dim_(canonical_axis_w, GetDimsVector(in[1]));
vector<int> y_shape(in[0].dims().begin(), in[0].dims().end());
CAFFE_ENFORCE_LE(canonical_axis + 1, y_shape.size());
y_shape.resize(canonical_axis + 1);
y_shape[canonical_axis] = N;
out[0] = CreateTensorShape(y_shape, in[0].data_type());
return out;
}
// Cost inference for the FC operator (three inputs: X, W, b).
//
// X = in[0] is viewed as an M x K matrix split at the canonicalized "axis"
// argument (default 1). N is the product of W's (in[1]) dims before the
// canonicalized "axis_w" argument (default 1).
// NOTE(review): N is always computed as for a non-pretransposed W. This
// function has no pretransposed_weight flag, unlike FCShapeInference.
//
// Reported costs:
//   flops         = 2*K*M*N + M*N  (the M*N term presumably covers the bias add)
//   bytes_read    = total element count of X, W and b, times X's element size
//   bytes_written = M*N elements, times X's element size
//   params_bytes  = K*N + N elements, times X's element size
//
// All arithmetic is done in uint64_t. Products such as 2*K*M*N exceed INT_MAX
// for ordinary layer sizes, and signed int overflow is undefined behavior.
// The fields of OpSchema::Cost are themselves 64-bit unsigned.
OpSchema::Cost CostInferenceForFC(
    const OperatorDef& def,
    const vector<TensorShape>& in) {
  CAFFE_ENFORCE_EQ(in.size(), 3, "FC requires three inputs");
  struct OpSchema::Cost c;
  ArgumentHelper helper(def);

  const auto& X = in[0];
  const auto& W = in[1];
  const auto& b = in[2];

  auto axis = helper.GetSingleArgument<int32_t>("axis", 1);
  const auto canonical_axis = canonical_axis_index_(axis, in[0].dims().size());
  const uint64_t M =
      static_cast<uint64_t>(size_to_dim_(canonical_axis, GetDimsVector(X)));
  const uint64_t K =
      static_cast<uint64_t>(size_from_dim_(canonical_axis, GetDimsVector(X)));

  auto axis_w = helper.GetSingleArgument<int32_t>("axis_w", 1);
  const int canonical_axis_w =
      canonical_axis_index_(axis_w, in[1].dims().size());
  const uint64_t N =
      static_cast<uint64_t>(size_to_dim_(canonical_axis_w, GetDimsVector(W)));

  // Byte size of one element of X's data type. sizeof(X.data_type()) would be
  // the size of the enum type itself, not of the element it describes.
  const uint64_t elem_size = DataTypeToTypeMeta(X.data_type()).itemsize();

  const uint64_t nElemX = nElemFromDim(X);
  const uint64_t nElemW = nElemFromDim(W);
  const uint64_t nElemB = nElemFromDim(b);

  c.flops = 2 * K * M * N + M * N;
  c.bytes_read = (nElemX + nElemW + nElemB) * elem_size;
  c.bytes_written = M * N * elem_size;
  c.params_bytes = (K * N + N) * elem_size;
  return c;
}
} // namespace caffe2