blob: 35f787b8f6f6192a15f1f82b2d2115dec3e173d2 [file] [log] [blame]
#include "caffe2/core/blob.h"
#include "caffe2/core/blob_serialization.h"
#include "caffe2/mkl/mkl_utils.h"
#ifdef CAFFE2_HAS_MKL_DNN
namespace caffe2 {
namespace mkl {
/**
* @brief MKLMemorySerializer is the serializer for MKLMemory.
*
* MKLMemorySerializer takes in a blob that contains an MKLMemory, and
* serializes it into a TensorProto protocol buffer.
*/
class MKLMemorySerializer : public BlobSerializerBase {
public:
MKLMemorySerializer() {}
~MKLMemorySerializer() {}
void Serialize(
const Blob& blob,
const string& name,
SerializationAcceptor acceptor) override {
BlobProto blob_proto;
blob_proto.set_name(name);
blob_proto.set_type(kTensorBlobType);
TensorProto* proto = blob_proto.mutable_tensor();
auto* device_detail = proto->mutable_device_detail();
device_detail->set_device_type(MKLDNN);
proto->set_name(name);
if (blob.IsType<MKLMemory<float>>()) {
const MKLMemory<float>& src = blob.Get<MKLMemory<float>>();
CAFFE_ENFORCE(
src.buffer(), "Cannot serialize an empty MKLMemory object.");
size_t total = 1;
for (int i = 0; i < src.dims().size(); ++i) {
proto->add_dims(src.dims()[i]);
total *= src.dims()[i];
}
proto->mutable_float_data()->Reserve(total);
while (total--) {
proto->add_float_data(0);
}
src.CopyTo(proto->mutable_float_data()->mutable_data());
} else if (blob.IsType<MKLMemory<double>>()) {
const MKLMemory<double>& src = blob.Get<MKLMemory<double>>();
CAFFE_ENFORCE(
src.buffer(), "Cannot serialize an empty MKLMemory object.");
size_t total = 1;
for (int i = 0; i < src.dims().size(); ++i) {
proto->add_dims(src.dims()[i]);
total *= src.dims()[i];
}
proto->mutable_double_data()->Reserve(total);
while (total--) {
proto->add_double_data(0);
}
src.CopyTo(proto->mutable_double_data()->mutable_data());
} else {
CAFFE_THROW(
"MKLMemory could only be either float or double. "
"Encountered unsupported type.");
}
acceptor(name, blob_proto.SerializeAsString());
}
};
/**
* @brief MKLMemoryDeserializer is the deserializer for TensorProto that has
* MKLDNN as its device.
*
* The device that the deserialized Tensor will live under is determined by the
* device_detail field. If you want to specify the device of the deserialized
* tensor, change the TensorProto's corresponding fields before calling
* Deserialize.
*/
class MKLMemoryDeserializer : public BlobDeserializerBase {
public:
void Deserialize(const BlobProto& blob_proto, Blob* blob) override {
const TensorProto& proto = blob_proto.tensor();
CAFFE_ENFORCE(
proto.data_type() == TensorProto_DataType_FLOAT ||
proto.data_type() == TensorProto_DataType_DOUBLE,
"MKLMemory only supports either float or double formats.");
CAFFE_ENFORCE(
!proto.has_segment(), "MKLMemory does not support segment right now.");
vector<TIndex> dims;
for (const TIndex d : proto.dims()) {
dims.push_back(d);
}
// TODO: right now, every time we do a deserializer we create a new MKL
// Memory object. Optionally, we can change that.
switch (proto.data_type()) {
case TensorProto_DataType_FLOAT: {
auto dst = make_unique<MKLMemory<float>>(dims);
dst->CopyFrom(proto.float_data().data());
blob->Reset(dst.release());
break;
}
case TensorProto_DataType_DOUBLE: {
auto dst = make_unique<MKLMemory<double>>(dims);
dst->CopyFrom(proto.double_data().data());
blob->Reset(dst.release());
break;
}
default:
CAFFE_THROW("This should not happen, we guarded things above already.");
}
}
};
} // namespace mkl
// Register mkl::MKLMemorySerializer as the blob serializer for both element
// types it handles, keyed by the TypeMeta id of MKLMemory<float> and
// MKLMemory<double> respectively.
REGISTER_BLOB_SERIALIZER(
(TypeMeta::Id<mkl::MKLMemory<float>>()),
mkl::MKLMemorySerializer);
REGISTER_BLOB_SERIALIZER(
(TypeMeta::Id<mkl::MKLMemory<double>>()),
mkl::MKLMemorySerializer);
// Register mkl::MKLMemoryDeserializer under the TensorMKLDNN key.
// NOTE(review): presumably this key is what the generic blob deserialization
// path derives for tensors whose device type is MKLDNN -- confirm against the
// deserializer registry lookup.
REGISTER_BLOB_DESERIALIZER(TensorMKLDNN, mkl::MKLMemoryDeserializer);
} // namespace caffe2
#endif // CAFFE2_HAS_MKL_DNN