blob: 5d45f845591013d63b3289591aeebec9613a7a87 [file] [log] [blame]
#ifndef CAFFE2_UTILS_MKL_SGEMM_PACK_H_
#define CAFFE2_UTILS_MKL_SGEMM_PACK_H_
#include "caffe2/core/logging.h"
namespace caffe2 {
namespace mkl {
struct MKLPackedMatrix {
CBLAS_IDENTIFIER identifier_;
CBLAS_TRANSPOSE trans_;
int m_;
int n_;
int k_;
float alpha_;
int ld_;
float* data_ = nullptr;
MKLPackedMatrix(
const CBLAS_IDENTIFIER identifier,
const CBLAS_TRANSPOSE trans,
const int m,
const int n,
const int k,
const float alpha,
const float* src,
const int ld)
: identifier_(identifier),
trans_(trans),
m_(m),
n_(n),
k_(k),
alpha_(alpha),
ld_(ld) {
data_ = cblas_sgemm_alloc(identifier, m, n, k);
CAFFE_ENFORCE(data_, "MKL runtime error: cannot allocate sgemm memory.");
cblas_sgemm_pack(
CblasRowMajor, identifier, trans, m, n, k, alpha, src, ld, data_);
}
~MKLPackedMatrix() {
if (data_) {
cblas_sgemm_free(data_);
}
}
};
} // namespace mkl
} // namespace caffe2
#endif // CAFFE2_UTILS_MKL_SGEMM_PACK_H_