// blob: 26fb02bac629adeda54dd20ad7139d30e6011cfa [file] [log] [blame]
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "caffe2/core/blob.h"
#include "caffe2/core/blob_serialization.h"
#include "caffe2/core/common.h"
#include "caffe2/mkl/mkl_utils.h"
#include <gtest/gtest.h>
#ifdef CAFFE2_HAS_MKL_DNN
namespace caffe2 {

using mkl::MKLMemory;

// Round-trips an MKLMemory<float> through Blob serialization.
//
// The test has two halves:
//   1. Serialize a {2, 3, 4} MKLMemory<float> filled with 0..23 and check
//      that the resulting BlobProto is named "test", has type "Tensor", and
//      carries the same float values in order.
//   2. Deserialize the bytes into a fresh Blob and check that it holds an
//      MKLMemory<float> with the same dims and contents.
//
// Checks that guard subsequent indexing or typed access use ASSERT_* so a
// failure stops the test instead of running into out-of-range access.
TEST(MKLTest, MKLMemorySerialization) {
  constexpr int kNumElements = 2 * 3 * 4;

  // Build the source data: element i holds the value i.
  vector<int> shape{2, 3, 4};
  float data[kNumElements];
  for (int i = 0; i < kNumElements; ++i) {
    data[i] = static_cast<float>(i);
  }

  Blob blob;
  blob.Reset<MKLMemory<float>>(new MKLMemory<float>(shape));
  MKLMemory<float>* mkl_memory = blob.GetMutable<MKLMemory<float>>();
  mkl_memory->CopyFrom(data);

  // --- Serialization half -------------------------------------------------
  string serialized = blob.Serialize("test");
  BlobProto proto;
  // A parse failure makes every later check meaningless; fail this test
  // rather than aborting the whole test binary.
  ASSERT_TRUE(proto.ParseFromString(serialized));
  EXPECT_EQ(proto.name(), "test");
  EXPECT_EQ(proto.type(), "Tensor");
  EXPECT_TRUE(proto.has_tensor());
  const TensorProto& tensor_proto = proto.tensor();
  EXPECT_EQ(
      tensor_proto.data_type(), TypeMetaToDataType(TypeMeta::Make<float>()));
  // Must hold before indexing float_data(i) below.
  ASSERT_EQ(tensor_proto.float_data_size(), kNumElements);
  for (int i = 0; i < kNumElements; ++i) {
    EXPECT_EQ(tensor_proto.float_data(i), static_cast<float>(i));
  }

  // --- Deserialization half -----------------------------------------------
  Blob new_blob;
  EXPECT_NO_THROW(new_blob.Deserialize(serialized));
  // Must hold before the typed Get<> below.
  ASSERT_TRUE(new_blob.IsType<MKLMemory<float>>());
  // Inspect the *deserialized* blob; reading from `blob` here would only
  // re-check the original object and never exercise Deserialize's result.
  const auto& new_mkl_memory = new_blob.Get<MKLMemory<float>>();
  // Must hold before indexing dims()[0..2] below.
  ASSERT_EQ(new_mkl_memory.dims().size(), 3);
  EXPECT_EQ(new_mkl_memory.dims()[0], 2);
  EXPECT_EQ(new_mkl_memory.dims()[1], 3);
  EXPECT_EQ(new_mkl_memory.dims()[2], 4);

  float recovered_data[kNumElements];
  new_mkl_memory.CopyTo(recovered_data);
  for (int i = 0; i < kNumElements; ++i) {
    EXPECT_EQ(recovered_data[i], static_cast<float>(i));
  }
}

} // namespace caffe2
#endif // CAFFE2_HAS_MKL_DNN