#include <array>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_set>
#include <vector>

#include <gtest/gtest.h>

#include <c10/util/Logging.h>
#include <c10/util/irange.h>

#include "caffe2/serialize/inline_container.h"
#include "caffe2/serialize/istream_adapter.h"

namespace caffe2 {
namespace serialize {
namespace {

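// Round-trip test: write two records through the stream writer, then read
// them back three ways (owning DataPtr, caller-provided buffer, and chunked
// copy) and verify both the payloads and the record-offset alignment.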
TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
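  // Record payloads are aligned to 64-byte boundaries inside the zip archive,
  // so every record offset should be a multiple of kFieldAlignment.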
  constexpr int64_t kFieldAlignment = 64L;

  std::ostringstream oss;
  // Write records through the writer.
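  // The writer hands each block of archive bytes to this callback; returning
  // a count smaller than n reports a write failure back to the writer.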
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 127> data1;
  // In-place memory buffer, reused below for chunked reads.
  std::vector<uint8_t> buf(data1.size());

  for (auto i : c10::irange(data1.size())) {
    data1[i] = data1.size() - i;
  }
  writer.writeRecord("key1", data1.data(), data1.size());

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 64> data2;
  for (auto i : c10::irange(data2.size())) {
    data2[i] = data2.size() - i;
  }
  writer.writeRecord("key2", data2.data(), data2.size());

  const std::unordered_set<std::string>& written_records =
      writer.getAllWrittenRecords();
  ASSERT_EQ(written_records.size(), 2);
  ASSERT_EQ(written_records.count("key1"), 1);
  ASSERT_EQ(written_records.count("key2"), 1);

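  // writeEndOfFile() finalizes the archive; it also emits the
  // serialization-id record, which then shows up in the written-record set.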
  writer.writeEndOfFile();
  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);

  std::string the_file = oss.str();
  const char* file_name = "output.zip";
  std::ofstream foo(file_name, std::ios::binary);
  foo.write(the_file.c_str(), the_file.size());
  foo.close();

  std::istringstream iss(the_file);

  // Read the records back through a reader.
  PyTorchStreamReader reader(&iss);
  ASSERT_TRUE(reader.hasRecord("key1"));
  ASSERT_TRUE(reader.hasRecord("key2"));
  ASSERT_FALSE(reader.hasRecord("key2000"));
  at::DataPtr data_ptr;
  int64_t size = 0;
  std::tie(data_ptr, size) = reader.getRecord("key1");
  size_t off1 = reader.getRecordOffset("key1");
  ASSERT_EQ(size, data1.size());
  ASSERT_EQ(memcmp(data_ptr.get(), data1.data(), data1.size()), 0);
  ASSERT_EQ(memcmp(the_file.c_str() + off1, data1.data(), data1.size()), 0);
  ASSERT_EQ(off1 % kFieldAlignment, 0);
  // In-place getRecord() test: read directly into a caller-owned buffer.
  std::vector<uint8_t> dst(size);
  size_t ret = reader.getRecord("key1", dst.data(), size);
  ASSERT_EQ(ret, size);
  ASSERT_EQ(memcmp(dst.data(), data1.data(), size), 0);
  // Chunked getRecord() test.
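  // The record is copied in 3-byte chunks: each chunk is staged in buf and
  // then moved into dst by the user-supplied copy callback.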
  ret = reader.getRecord(
      "key1", dst.data(), size, 3, buf.data(), [](void* dst, const void* src, size_t n) {
        memcpy(dst, src, n);
      });
  ASSERT_EQ(ret, size);
  ASSERT_EQ(memcmp(dst.data(), data1.data(), size), 0);

  std::tie(data_ptr, size) = reader.getRecord("key2");
  size_t off2 = reader.getRecordOffset("key2");
  ASSERT_EQ(off2 % kFieldAlignment, 0);

  ASSERT_EQ(size, data2.size());
  ASSERT_EQ(memcmp(data_ptr.get(), data2.data(), data2.size()), 0);
  ASSERT_EQ(memcmp(the_file.c_str() + off2, data2.data(), data2.size()), 0);
  // In-place getRecord() test.
  dst.resize(size);
  ret = reader.getRecord("key2", dst.data(), size);
  ASSERT_EQ(ret, size);
  ASSERT_EQ(memcmp(dst.data(), data2.data(), size), 0);
  // Chunked getRecord() test.
  ret = reader.getRecord(
      "key2", dst.data(), size, 3, buf.data(), [](void* dst, const void* src, size_t n) {
        memcpy(dst, src, n);
      });
  ASSERT_EQ(ret, size);
  ASSERT_EQ(memcmp(dst.data(), data2.data(), size), 0);
  // Clean up the file written above.
  remove(file_name);
}

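// Multi-reader test: read the archive with a pool of additional stream
// adapters so that records can be fetched by several readers in parallel.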
TEST(PyTorchStreamWriterAndReader, LoadWithMultiThreads) {
  std::ostringstream oss;
  // Write records through the writer.
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 127> data1;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 64> data2;
  for (auto i : c10::irange(data1.size())) {
    data1[i] = data1.size() - i;
  }
  writer.writeRecord("key1", data1.data(), data1.size());

  for (auto i : c10::irange(data2.size())) {
    data2[i] = data2.size() - i;
  }
  writer.writeRecord("key2", data2.data(), data2.size());

  const std::unordered_set<std::string>& written_records =
      writer.getAllWrittenRecords();
  ASSERT_EQ(written_records.size(), 2);
  ASSERT_EQ(written_records.count("key1"), 1);
  ASSERT_EQ(written_records.count("key2"), 1);

  writer.writeEndOfFile();
  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);

  std::string the_file = oss.str();
  const char* file_name = "output.zip";
  std::ofstream foo(file_name, std::ios::binary);
  foo.write(the_file.c_str(), the_file.size());
  foo.close();

  // Read the records back through a PyTorchStreamReader.
  std::istringstream iss(the_file);
  PyTorchStreamReader reader(&iss);
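  // A threshold of 0 makes the reader use the additional readers even for
  // these small records, so the parallel read path is always exercised.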
  reader.setAdditionalReaderSizeThreshold(0);
  // Sanity check before the multi-reader tests.
  int64_t size1 = 0, size2 = 0, ret = 0;
  at::DataPtr data_ptr;
  std::tie(data_ptr, size1) = reader.getRecord("key1");
  std::tie(data_ptr, size2) = reader.getRecord("key2");

  // Test getRecord(name, additional_readers).
  std::vector<std::shared_ptr<ReadAdapterInterface>> additionalReader;
  for (int i = 0; i < 10; ++i) {
    // Test additional-reader pools of various sizes (0 to 9 readers).
    std::tie(data_ptr, ret) = reader.getRecord("key1", additionalReader);
    ASSERT_EQ(ret, size1);
    ASSERT_EQ(memcmp(data_ptr.get(), data1.data(), size1), 0);

    std::tie(data_ptr, ret) = reader.getRecord("key2", additionalReader);
    ASSERT_EQ(ret, size2);
    ASSERT_EQ(memcmp(data_ptr.get(), data2.data(), size2), 0);
    additionalReader.push_back(std::make_unique<IStreamAdapter>(&iss));
  }

  // In-place multi-threaded getRecord(name, dst, n, additional_readers) test.
  additionalReader.clear();
  std::vector<uint8_t> dst1(size1), dst2(size2);
  for (int i = 0; i < 10; ++i) {
    // Test reader pools of various sizes (1 to 10 readers).
    additionalReader.push_back(std::make_unique<IStreamAdapter>(&iss));

    ret = reader.getRecord("key1", dst1.data(), size1, additionalReader);
    ASSERT_EQ(ret, size1);
    ASSERT_EQ(memcmp(dst1.data(), data1.data(), size1), 0);

    ret = reader.getRecord("key2", dst2.data(), size2, additionalReader);
    ASSERT_EQ(ret, size2);
    ASSERT_EQ(memcmp(dst2.data(), data2.data(), size2), 0);
  }
  // Clean up the file written above.
  remove(file_name);
}

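// Requesting a record name that was never written must throw c10::Error from
// every getRecord() overload, and must leave the reader usable afterwards.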
TEST(PyTorchStreamWriterAndReader, GetNonexistentRecordThrows) {
  std::ostringstream oss;
  // Write records through the writer.
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 127> data1;

  // In-place memory buffer for the chunked getRecord() call below.
  std::vector<uint8_t> buf(data1.size());

  for (auto i : c10::irange(data1.size())) {
    data1[i] = data1.size() - i;
  }
  writer.writeRecord("key1", data1.data(), data1.size());

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 64> data2;
  for (auto i : c10::irange(data2.size())) {
    data2[i] = data2.size() - i;
  }
  writer.writeRecord("key2", data2.data(), data2.size());

  const std::unordered_set<std::string>& written_records =
      writer.getAllWrittenRecords();
  ASSERT_EQ(written_records.size(), 2);
  ASSERT_EQ(written_records.count("key1"), 1);
  ASSERT_EQ(written_records.count("key2"), 1);

  writer.writeEndOfFile();
  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);

  std::string the_file = oss.str();
  const char* file_name = "output2.zip";
  std::ofstream foo(file_name, std::ios::binary);
  foo.write(the_file.c_str(), the_file.size());
  foo.close();

  std::istringstream iss(the_file);

  // Read the records back through a reader.
  PyTorchStreamReader reader(&iss);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(reader.getRecord("key3"), c10::Error);
  std::vector<uint8_t> dst(data1.size());
  EXPECT_THROW(reader.getRecord("key3", dst.data(), data1.size()), c10::Error);
  EXPECT_THROW(
      reader.getRecord(
          "key3",
          dst.data(),
          data1.size(),
          3,
          buf.data(),
          [](void* dst, const void* src, size_t n) { memcpy(dst, src, n); }),
      c10::Error);

  // The reader should still work after throwing.
  EXPECT_TRUE(reader.hasRecord("key1"));
  // Clean up the file written above.
  remove(file_name);
}

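// With debug-symbol loading disabled, *.debug_pkl records are skipped: they
// are reported as absent, and reads of them return size 0 instead of data.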
TEST(PyTorchStreamWriterAndReader, SkipDebugRecords) {
  std::ostringstream oss;
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 127> data1;
  // In-place memory buffer for the chunked getRecord() call below.
  std::vector<uint8_t> buf(data1.size());

  for (auto i : c10::irange(data1.size())) {
    data1[i] = data1.size() - i;
  }
  writer.writeRecord("key1.debug_pkl", data1.data(), data1.size());

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 64> data2;
  for (auto i : c10::irange(data2.size())) {
    data2[i] = data2.size() - i;
  }
  writer.writeRecord("key2.debug_pkl", data2.data(), data2.size());

  const std::unordered_set<std::string>& written_records =
      writer.getAllWrittenRecords();
  ASSERT_EQ(written_records.size(), 2);
  ASSERT_EQ(written_records.count("key1.debug_pkl"), 1);
  ASSERT_EQ(written_records.count("key2.debug_pkl"), 1);
  writer.writeEndOfFile();
  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);

  std::string the_file = oss.str();
  const char* file_name = "output3.zip";
  std::ofstream foo(file_name, std::ios::binary);
  foo.write(the_file.c_str(), the_file.size());
  foo.close();

  std::istringstream iss(the_file);

  // Read the records back through a reader.
  PyTorchStreamReader reader(&iss);
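  // After disabling debug symbols, hasRecord() on a *.debug_pkl name returns
  // false and every getRecord() overload reports size 0 rather than throwing.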
  reader.setShouldLoadDebugSymbol(false);
  EXPECT_FALSE(reader.hasRecord("key1.debug_pkl"));
  at::DataPtr ptr;
  size_t size = 0;
  std::tie(ptr, size) = reader.getRecord("key1.debug_pkl");
  EXPECT_EQ(size, 0);
  std::vector<uint8_t> dst(data1.size());
  size_t ret = reader.getRecord("key1.debug_pkl", dst.data(), data1.size());
  EXPECT_EQ(ret, 0);
  ret = reader.getRecord(
      "key1.debug_pkl",
      dst.data(),
      data1.size(),
      3,
      buf.data(),
      [](void* dst, const void* src, size_t n) { memcpy(dst, src, n); });
  EXPECT_EQ(ret, 0);
  // Clean up the file written above.
  remove(file_name);
}

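// The serialization id must be deterministic: a reader reports the id of the
// writer that produced the archive, and writing identical content twice
// yields the same id.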
TEST(PyTorchStreamWriterAndReader, ValidSerializationId) {
  std::ostringstream oss;
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 127> data1;

  for (auto i : c10::irange(data1.size())) {
    data1[i] = data1.size() - i;
  }
  writer.writeRecord("key1.debug_pkl", data1.data(), data1.size());
  writer.writeEndOfFile();
  auto writer_serialization_id = writer.serializationId();

  std::string the_file = oss.str();

  std::istringstream iss(the_file);

  // Read the records back through a reader.
  PyTorchStreamReader reader(&iss);
  EXPECT_EQ(reader.serializationId(), writer_serialization_id);

  // Write the same content a second time.
  PyTorchStreamWriter writer2([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });
  writer2.writeRecord("key1.debug_pkl", data1.data(), data1.size());
  writer2.writeEndOfFile();
  auto writer2_serialization_id = writer2.serializationId();

  EXPECT_EQ(writer_serialization_id, writer2_serialization_id);
}

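// A caller-supplied record named kSerializationIdRecordName must be skipped:
// it is not tracked as a written record, and writeEndOfFile() still emits the
// writer's own serialization id, which is what the reader then reports.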
TEST(PyTorchStreamWriterAndReader, SkipDuplicateSerializationIdRecords) {
  std::ostringstream oss;
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });

  std::string dup_serialization_id = "dup-serialization-id";
  writer.writeRecord(
      kSerializationIdRecordName,
      dup_serialization_id.c_str(),
      dup_serialization_id.size());

  const std::unordered_set<std::string>& written_records =
      writer.getAllWrittenRecords();
  ASSERT_EQ(written_records.size(), 0);
  writer.writeEndOfFile();
  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);
  auto writer_serialization_id = writer.serializationId();

  std::string the_file = oss.str();
  const char* file_name = "output4.zip";
  std::ofstream foo(file_name, std::ios::binary);
  foo.write(the_file.c_str(), the_file.size());
  foo.close();

  std::istringstream iss(the_file);

  // Read the records back through a reader.
  PyTorchStreamReader reader(&iss);
  EXPECT_EQ(reader.serializationId(), writer_serialization_id);
  // Clean up the file written above.
  remove(file_name);
}

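// Writer and reader each log one API-usage metadata event (the writer once
// the archive is finalized, the reader on construction); the test captures
// both and compares them against the expected contents.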
TEST(PyTorchStreamWriterAndReader, LogAPIUsageMetadata) {
  std::map<std::string, std::map<std::string, std::string>> logs;

  SetAPIUsageMetadataLogger(
      [&](const std::string& context,
          const std::map<std::string, std::string>& metadata_map) {
        logs.insert({context, metadata_map});
      });
  std::ostringstream oss;
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });
  writer.writeEndOfFile();

  std::istringstream iss(oss.str());
  // Read the records back through a reader.
  PyTorchStreamReader reader(&iss);

  ASSERT_EQ(logs.size(), 2);
  std::map<std::string, std::map<std::string, std::string>> expected_logs = {
      {"pytorch.stream.writer.metadata",
       {{"serialization_id", writer.serializationId()},
        {"file_name", "archive"},
        {"file_size", std::to_string(oss.str().length())}}},
      {"pytorch.stream.reader.metadata",
       {{"serialization_id", writer.serializationId()},
        {"file_name", "archive"},
        {"file_size", std::to_string(iss.str().length())}}}
  };
  ASSERT_EQ(expected_logs, logs);

  // Reset the logger so later tests are not affected.
  SetAPIUsageMetadataLogger(
      [](const std::string& /*context*/,
         const std::map<std::string, std::string>& /*metadata_map*/) {});
}

class ChunkRecordIteratorTest : public ::testing::TestWithParam<int64_t> {};
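// Chunk sizes cover the interesting cases for the 1000-byte record below:
// 100 divides it evenly, 150 leaves a 100-byte final chunk, and 1010 is
// larger than the record, so a single partial read returns everything.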
INSTANTIATE_TEST_SUITE_P(
    ChunkRecordIteratorTestGroup,
    ChunkRecordIteratorTest,
    testing::Values(100, 150, 1010));

TEST_P(ChunkRecordIteratorTest, ChunkRead) {
  auto chunkSize = GetParam();
  std::string zipFileName =
      "output_chunk_" + std::to_string(chunkSize) + ".zip";
  const char* fileName = zipFileName.c_str();
  const std::string recordName = "key1";
  const size_t tensorDataSizeInBytes = 1000;

  // Write the record through the writer.
  std::ostringstream oss(std::ios::binary);
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });

  auto tensorData = std::vector<uint8_t>(tensorDataSizeInBytes, 1);
  auto dataPtr = tensorData.data();
  writer.writeRecord(recordName, dataPtr, tensorDataSizeInBytes);
  const std::unordered_set<std::string>& written_records =
      writer.getAllWrittenRecords();
  ASSERT_EQ(written_records.size(), 1);
  ASSERT_EQ(written_records.count(recordName), 1);
  writer.writeEndOfFile();
  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);

  std::string the_file = oss.str();
  std::ofstream foo(fileName, std::ios::binary);
  foo.write(the_file.c_str(), the_file.size());
  foo.close();
  LOG(INFO) << "Finished saving tensor into zip file " << fileName;

  LOG(INFO) << "Testing chunk size " << chunkSize;
  PyTorchStreamReader reader(fileName);
  ASSERT_TRUE(reader.hasRecord(recordName));
  auto chunkIterator = reader.createChunkReaderIter(
      recordName, tensorDataSizeInBytes, chunkSize);
  std::vector<uint8_t> buffer(chunkSize);
  size_t totalReadSize = 0;
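  // next() fills the buffer with up to chunkSize bytes and returns the number
  // of bytes actually read; a return value of 0 marks the end of the record.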
  while (auto readSize = chunkIterator.next(buffer.data())) {
    auto expectedData = std::vector<uint8_t>(readSize, 1);
    ASSERT_EQ(memcmp(expectedData.data(), buffer.data(), readSize), 0);
    totalReadSize += readSize;
  }
  ASSERT_EQ(totalReadSize, tensorDataSizeInBytes);
  // Clean up the file written above.
  remove(fileName);
}

} // namespace
} // namespace serialize
} // namespace caffe2