blob: 512b4e6c9d789d89da4d555f4c5c002f4d3666f2 [file] [log] [blame]
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/tensor_coding.h"
#include <vector>
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/protobuf.h"
#if defined(TENSORFLOW_PROTOBUF_USES_CORD)
#include "strings/cord_varint.h"
#endif // defined(TENSORFLOW_PROTOBUF_USES_CORD)
namespace tensorflow {
namespace port {
void AssignRefCounted(StringPiece src, core::RefCounted* obj, string* out) {
out->assign(src.data(), src.size());
}
void EncodeStringList(const tstring* strings, int64 n, string* out) {
out->clear();
for (int i = 0; i < n; ++i) {
core::PutVarint32(out, strings[i].size());
}
for (int i = 0; i < n; ++i) {
out->append(strings[i]);
}
}
bool DecodeStringList(const string& src, tstring* strings, int64 n) {
std::vector<uint32> sizes(n);
StringPiece reader(src);
int64 tot = 0;
for (auto& v : sizes) {
if (!core::GetVarint32(&reader, &v)) return false;
tot += v;
}
if (tot != static_cast<int64>(reader.size())) {
return false;
}
tstring* data = strings;
for (int64 i = 0; i < n; ++i, ++data) {
auto size = sizes[i];
if (size > reader.size()) {
return false;
}
data->assign(reader.data(), size);
reader.remove_prefix(size);
}
return true;
}
void CopyFromArray(string* s, const char* base, size_t bytes) {
s->assign(base, bytes);
}
class StringListEncoderImpl : public StringListEncoder {
public:
explicit StringListEncoderImpl(string* out) : out_(out) {}
~StringListEncoderImpl() override = default;
void Append(const protobuf::MessageLite& m) override {
core::PutVarint32(out_, m.ByteSizeLong());
tensorflow::string serialized_message;
m.AppendToString(&serialized_message);
strings::StrAppend(&rest_, serialized_message);
}
void Append(const string& s) override {
core::PutVarint32(out_, s.length());
strings::StrAppend(&rest_, s);
}
void Finalize() override { strings::StrAppend(out_, rest_); }
private:
string* out_;
string rest_;
};
class StringListDecoderImpl : public StringListDecoder {
public:
explicit StringListDecoderImpl(const string& in) : reader_(in) {}
~StringListDecoderImpl() override = default;
bool ReadSizes(std::vector<uint32>* sizes) override {
int64 total = 0;
for (auto& size : *sizes) {
if (!core::GetVarint32(&reader_, &size)) return false;
total += size;
}
if (total != static_cast<int64>(reader_.size())) {
return false;
}
return true;
}
const char* Data(uint32 size) override {
const char* data = reader_.data();
reader_.remove_prefix(size);
return data;
}
private:
StringPiece reader_;
};
std::unique_ptr<StringListEncoder> NewStringListEncoder(string* out) {
return std::unique_ptr<StringListEncoder>(new StringListEncoderImpl(out));
}
std::unique_ptr<StringListDecoder> NewStringListDecoder(const string& in) {
return std::unique_ptr<StringListDecoder>(new StringListDecoderImpl(in));
}
#if defined(TENSORFLOW_PROTOBUF_USES_CORD)
void AssignRefCounted(StringPiece src, core::RefCounted* obj, Cord* out) {
obj->Ref();
out->Clear();
// Defines a lambda to unref "obj" when Cord deletes this piece of
// memory. +[] converts the lambda to a C style function pointer.
auto cleanup = +[](absl::string_view donotcare, void* obj) {
reinterpret_cast<core::RefCounted*>(obj)->Unref();
};
out->AppendExternalMemory(absl::string_view(src.data(), src.size()), obj,
cleanup);
}
void EncodeStringList(const tstring* strings, int64 n, Cord* out) {
out->Clear();
for (int i = 0; i < n; ++i) {
::strings::CordAppendVarint(strings[i].size(), out);
}
for (int i = 0; i < n; ++i) {
out->Append(strings[i]);
}
}
bool DecodeStringList(const Cord& src, tstring* strings, int64 n) {
std::vector<uint32> sizes(n);
CordReader reader(src);
int64 tot = 0;
for (auto& v : sizes) {
if (!::strings::CordReaderReadVarint(&reader, &v)) return false;
tot += v;
}
if (tot != reader.Available()) {
return false;
}
tstring* data = strings;
for (int i = 0; i < n; ++i, ++data) {
auto size = sizes[i];
if (size > reader.Available()) {
return false;
}
#ifdef USE_TSTRING
// TODO(dero): Consider adding resize_uninitialized() to tstring once the
// tstring placeholder is replaced with the actual implementation.
//
// Currently, in the case of USE_TSTRING, the placeholder tstring class
// encapsulates a single std::string. We avoid using
// gtl::STLStringResizeUninitialized (and its associated header-include) in
// tstring.h as we have no intention in using it in the actual
// implementation. Thus, in the interim, we resort to resize().
data->resize(size);
reader.ReadN(size, data->data());
#else // USE_TSTRING
gtl::STLStringResizeUninitialized(data, size);
reader.ReadN(size, gtl::string_as_array(data));
#endif // USE_TSTRING
}
return true;
}
void CopyFromArray(Cord* c, const char* base, size_t bytes) {
c->CopyFrom(base, bytes);
}
class CordStringListEncoderImpl : public StringListEncoder {
public:
explicit CordStringListEncoderImpl(Cord* out) : out_(out) {}
~CordStringListEncoderImpl() override = default;
void Append(const protobuf::MessageLite& m) override {
::strings::CordAppendVarint(m.ByteSizeLong(), out_);
m.AppendToString(&rest_);
}
void Append(const string& s) override {
::strings::CordAppendVarint(s.length(), out_);
rest_.append(s.data(), s.size());
}
void Finalize() override { out_->Append(rest_); }
private:
Cord* out_;
string rest_;
};
class CordStringListDecoderImpl : public StringListDecoder {
public:
explicit CordStringListDecoderImpl(const Cord& in) : reader_(in) {}
~CordStringListDecoderImpl() override = default;
bool ReadSizes(std::vector<uint32>* sizes) override {
int64 total = 0;
for (auto& size : *sizes) {
if (!::strings::CordReaderReadVarint(&reader_, &size)) return false;
total += size;
}
if (total != static_cast<int64>(reader_.Available())) {
return false;
}
return true;
}
const char* Data(uint32 size) override {
tmp_.resize(size);
reader_.ReadN(size, tmp_.data());
return tmp_.data();
}
private:
CordReader reader_;
std::vector<char> tmp_;
};
std::unique_ptr<StringListEncoder> NewStringListEncoder(Cord* out) {
return std::unique_ptr<StringListEncoder>(new CordStringListEncoderImpl(out));
}
std::unique_ptr<StringListDecoder> NewStringListDecoder(const Cord& in) {
return std::unique_ptr<StringListDecoder>(new CordStringListDecoderImpl(in));
}
#endif // defined(TENSORFLOW_PROTOBUF_USES_CORD)
} // namespace port
} // namespace tensorflow