blob: 5e08e5a7a60f51768c285101294906e29ccb1670 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
#define TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
#include <iostream>
#include <type_traits>
#include <utility>
#include <vector>
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/abi.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
// Type used for tag-dispatch of the Encode/Decode Variant implementations. This
// template can determine whether the first type parameter `T` is one of the
// following:
//
// * A POD type (TypeResolver<T, true>)
// * A tensorflow::Tensor (TypeResolver<T, false, true>)
// * A protocol buffer (TypeResolver<T, false, false, true>)
// * None of the above (TypeResolver<T, false, false, false>)
//
template <typename T, bool = std::is_pod<typename std::decay<T>::type>::value,
bool = std::is_same<typename std::decay<T>::type,
::tensorflow::Tensor>::value,
bool = std::is_base_of<protobuf::MessageLite,
typename std::decay<T>::type>::value>
struct TypeResolver {};
// Specialization for POD type
template <typename T>
void EncodeVariantImpl(const T& value, TypeResolver<T, true /* is_pod */>,
VariantTensorData* data) {
data->set_metadata(value);
}
// Specialization for tensorflow::Tensor
template <typename T>
void EncodeVariantImpl(const T& value,
TypeResolver<T, false /* is_pod */, true /* Tensor */>,
VariantTensorData* data) {
data->tensors_.clear();
data->tensors_.push_back(value);
}
// Specialization for protobuf
template <typename T>
void EncodeVariantImpl(const T& value,
TypeResolver<T, false /* is_pod */, false /* Tensor */,
true /* protobuf */>,
VariantTensorData* data) {
value.SerializeToString(&data->metadata_);
}
// Specialization for other types
template <typename T>
void EncodeVariantImpl(const T& value,
TypeResolver<T, false /* is_pod */, false /* Tensor */,
false /* protobuf */>,
VariantTensorData* data) {
value.Encode(data);
}
// Specialization for POD type
template <typename T>
bool DecodeVariantImpl(VariantTensorData data,
TypeResolver<T, true /* is_pod */, false /* Tensor */,
false /* protobuf */>,
T* value) {
return data.get_metadata(value);
}
// Specialization for tensorflow::Tensor
template <typename T>
bool DecodeVariantImpl(VariantTensorData data,
TypeResolver<T, false /* is_pod */, true /* Tensor */,
false /* protobuf */>,
T* value) {
*value = data.tensors(0);
return true;
}
// Specialization for protobuf
template <typename T>
bool DecodeVariantImpl(VariantTensorData data,
TypeResolver<T, false /* is_pod */, false /* Tensor */,
true /* protobuf */>,
T* value) {
string metadata;
data.get_metadata(&metadata);
return value->ParseFromString(std::move(metadata));
}
// Specialization for other types
template <typename T>
bool DecodeVariantImpl(VariantTensorData data,
TypeResolver<T, false /* is_pod */, false /* Tensor */,
false /* protobuf */>,
T* value) {
return value->Decode(std::move(data));
}
template <typename C, typename = void>
struct has_type_name : std::false_type {};
template <typename C>
struct has_type_name<
C, typename std::enable_if<std::is_same<
decltype(std::declval<C>().TypeName()), string>::value>::type>
: std::true_type {};
template <typename T, bool = has_type_name<typename std::decay<T>::type>::value,
bool = std::is_same<typename std::decay<T>::type,
::tensorflow::Tensor>::value,
bool = std::is_base_of<protobuf::MessageLite,
typename std::decay<T>::type>::value>
struct TypeNameResolver {};
template <typename T>
string TypeNameVariantImpl(const T& value,
TypeNameResolver<T, true /* has_type_name */>) {
return value.TypeName();
}
template <typename T>
string TypeNameVariantImpl(
const T& value,
TypeNameResolver<T, false /* has_type_name */, true /* Tensor */>) {
return "tensorflow::Tensor";
}
template <typename T>
string TypeNameVariantImpl(
const T& value, TypeNameResolver<T, false /* has_type_name */,
false /* Tensor */, true /* protobuf */>) {
return value.GetTypeName();
}
template <typename T>
string TypeNameVariantImpl(
const T& value,
TypeNameResolver<T, false /* has_type_name */, false /* Tensor */,
false /* protobuf */>) {
return port::MaybeAbiDemangle(MakeTypeIndex<T>().name());
}
template <typename T>
string TypeNameVariant(const T& value) {
return TypeNameVariantImpl(value, TypeNameResolver<T>());
}
template <typename C, typename = void>
struct has_debug_string : std::false_type {};
template <typename C>
struct has_debug_string<
C, typename std::enable_if<std::is_same<
decltype(std::declval<C>().DebugString()), string>::value>::type>
: std::true_type {};
template <typename C, typename = void>
struct can_strcat : std::false_type {};
template <typename C>
struct can_strcat<
C, typename std::enable_if<std::is_same<
decltype(strings::StrCat(std::declval<C>())), string>::value>::type>
: std::true_type {};
template <typename T,
bool = has_debug_string<typename std::decay<T>::type>::value,
bool = can_strcat<typename std::decay<T>::type>::value>
struct DebugStringResolver {};
// TODO(ebrevdo): Expand DebugStringResolver to return TypeString if
// there is no StrCat<T>() constructor.
template <typename T>
string DebugStringVariantImpl(
const T& value, DebugStringResolver<T, true /* has_debug_string */>) {
return value.DebugString();
}
template <typename T>
string DebugStringVariantImpl(
const T& value, DebugStringResolver<T, false /* has_debug_string */,
true /* can_strcat */>) {
return strings::StrCat(value);
}
template <typename T>
string DebugStringVariantImpl(
const T& value, DebugStringResolver<T, false /* has_debug_string */,
false /* can_strcat */>) {
return "?";
}
template <typename T>
string DebugStringVariant(const T& value) {
return DebugStringVariantImpl(value, DebugStringResolver<T>());
}
template <typename T>
void EncodeVariant(const T& value, VariantTensorData* data) {
EncodeVariantImpl(value, TypeResolver<T>(), data);
data->set_type_name(TypeNameVariant(value));
}
template <typename T>
bool DecodeVariant(VariantTensorData* data, T* value) {
return DecodeVariantImpl(std::move(*data), TypeResolver<T>(), value);
}
template <typename T>
void EncodeVariant(const T& value, string* buf) {
VariantTensorData data;
EncodeVariantImpl(value, TypeResolver<T>(), &data);
data.set_type_name(TypeNameVariant(value));
DCHECK(buf != nullptr);
data.SerializeToString(buf);
}
template <typename T>
bool DecodeVariant(string* buf, T* value) {
VariantTensorData data;
if (!data.ParseFromString(*buf)) return false;
if (!DecodeVariantImpl(std::move(data), TypeResolver<T>(), value)) {
return false;
}
return true;
}
// Specializations for VariantTensorDataProto
template <>
string TypeNameVariant(const VariantTensorDataProto& value);
template <>
void EncodeVariant(const VariantTensorDataProto& value,
VariantTensorData* data);
template <>
bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value);
template <>
void EncodeVariant(const VariantTensorDataProto& value, string* buf);
template <>
bool DecodeVariant(string* buf, VariantTensorDataProto* value);
// Encodes an array of Variant objects in to the given StringListEncoder.
// `variant_array` is assumed to point to an array of `n` Variant objects.
void EncodeVariantList(const Variant* variant_array, int64 n,
std::unique_ptr<port::StringListEncoder> e);
// Decodes an array of Variant objects from the given StringListDecoder.
// `variant_array` is assumed to point to an array of `n` Variant objects.
bool DecodeVariantList(std::unique_ptr<port::StringListDecoder> d,
Variant* variant_array, int64 n);
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_