// blob: 925521c7c73ae8ca1896e9e0d085cf357e2c194a [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <sstream>
#include <string>
#include "tensorflow/c/kernels/tensor_shape_utils.h"
#include "tensorflow/c/kernels.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/framework/types.h"
namespace {
// Holds the status object and the two TF_Tensor inputs of the op kernel.
// The constructor creates the status and fetches input 0 (tags) and, only if
// that succeeded, input 1 (values); callers must inspect `status` before
// touching the tensors. The destructor releases all three handles, so a
// Params instance cleans up on every return path of the compute function.
struct Params {
  TF_Tensor* tags;
  TF_Tensor* values;
  TF_Status* status;

  // `explicit` prevents an accidental implicit conversion from a context
  // pointer into a Params object.
  explicit Params(TF_OpKernelContext* ctx)
      : tags(nullptr), values(nullptr), status(TF_NewStatus()) {
    TF_GetInput(ctx, 0, &tags, status);
    if (TF_GetCode(status) == TF_OK) {
      TF_GetInput(ctx, 1, &values, status);
    }
  }

  // The members are owning raw pointers that the destructor frees; a copy
  // would free the same handles twice, so copying is disabled.
  Params(const Params&) = delete;
  Params& operator=(const Params&) = delete;

  ~Params() {
    TF_DeleteStatus(status);
    // NOTE(review): tags/values are still nullptr here when the matching
    // TF_GetInput call failed or was skipped; this relies on TF_DeleteTensor
    // accepting nullptr -- confirm against tf_tensor.h.
    TF_DeleteTensor(tags);
    TF_DeleteTensor(values);
  }
};
// Placeholder create callback handed to TF_NewKernelBuilder at registration.
// It ignores the construction context and always returns nullptr.
void* ScalarSummaryOp_Create(TF_OpKernelConstruction* ctx) {
  (void)ctx;  // The construction context is not consulted.
  return nullptr;
}
// Placeholder delete callback handed to TF_NewKernelBuilder at registration.
// It does nothing with the pointer it receives.
void ScalarSummaryOp_Delete(void* kernel) {
  (void)kernel;  // Intentionally untouched.
}
// Helper functions for the compute method. They are only declared here; the
// definitions follow the compute function further down in this file.
// Returns true when both tensors have the same number of dimensions and the
// same size along every dimension, false otherwise.
bool IsSameSize(TF_Tensor* tensor1, TF_Tensor* tensor2);
// Returns the string " (tag '<tag>')" when `tags` holds exactly one element,
// and an empty string for any other element count (zero or several tags).
// The result is appended to the shape-mismatch error message.
std::string SingleTag(TF_Tensor* tags);
template<typename T>
void ScalarSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
Params params(ctx);
if (TF_GetCode(params.status) != TF_OK) {
TF_OpKernelContext_Failure(ctx, params.status);
return;
}
if (!IsSameSize(params.tags, params.values)) {
std::ostringstream err;
err << "tags and values not the same shape: "
<< tensorflow::ShapeDebugString(params.tags) << " != "
<< tensorflow::ShapeDebugString(params.values)
<< SingleTag(params.tags);
TF_SetStatus(params.status, TF_INVALID_ARGUMENT, err.str().c_str());
TF_OpKernelContext_Failure(ctx, params.status);
return;
}
// Convert tags and values tensor to array to access elements by index
tensorflow::Summary s;
auto tags_array = static_cast<tensorflow::tstring*>(
TF_TensorData(params.tags));
auto values_array = static_cast<T*>(TF_TensorData(params.values));
// Copy tags and values into summary protobuf
for (int i = 0; i < TF_TensorElementCount(params.tags); ++i) {
tensorflow::Summary::Value* v = s.add_value();
const tensorflow::tstring& Ttags_i = tags_array[i];
v->set_tag(Ttags_i.data(), Ttags_i.size());
v->set_simple_value(float(values_array[i]));
}
TF_Tensor* summary_tensor = TF_AllocateOutput(ctx, 0,
TF_ExpectedOutputDataType(ctx, 0), nullptr, 0,
sizeof(tensorflow::tstring), params.status);
if (TF_GetCode(params.status) != TF_OK){
TF_DeleteTensor(summary_tensor);
TF_OpKernelContext_Failure(ctx, params.status);
return;
}
tensorflow::tstring* output_tstring = reinterpret_cast<tensorflow::tstring*>(
TF_TensorData(summary_tensor));
CHECK(SerializeToTString(s, output_tstring));
TF_DeleteTensor(summary_tensor);
}
// Reports whether the two tensors have identical shapes: the same number of
// dimensions and the same size along each of them.
bool IsSameSize(TF_Tensor* tensor1, TF_Tensor* tensor2) {
  const auto rank = TF_NumDims(tensor1);
  if (rank != TF_NumDims(tensor2)) {
    return false;
  }
  // Advance while the dimensions agree; stopping early means a mismatch.
  int axis = 0;
  while (axis < rank && TF_Dim(tensor1, axis) == TF_Dim(tensor2, axis)) {
    ++axis;
  }
  return axis == rank;
}
// Formats the tag for use in an error message. Yields " (tag '<tag>')" when
// `tags` contains exactly one element, and an empty string otherwise.
std::string SingleTag(TF_Tensor* tags) {
  if (TF_TensorElementCount(tags) != 1) {
    return "";
  }
  const auto* only_tag =
      static_cast<tensorflow::tstring*>(TF_TensorData(tags));
  return tensorflow::strings::StrCat(" (tag '", only_tag->c_str(), "')");
}
// Registers the CPU "ScalarSummary" kernel for value type T.
//
// Builds a kernel from the create/compute/delete callbacks defined in this
// file, constrains the op attribute "T" to the TF_DataType corresponding to
// the C++ type T, and registers it. Any non-OK status aborts the process via
// CHECK_EQ.
template <typename T>
void RegisterScalarSummaryOpKernel() {
  TF_Status* status = TF_NewStatus();
  {
    // NOTE(review): the builder is never deleted here; presumably
    // TF_RegisterKernelBuilder takes ownership of it -- confirm against
    // tensorflow/c/kernels.h.
    auto* builder = TF_NewKernelBuilder("ScalarSummary",
                                        tensorflow::DEVICE_CPU,
                                        &ScalarSummaryOp_Create,
                                        &ScalarSummaryOp_Compute<T>,
                                        &ScalarSummaryOp_Delete);
    TF_KernelBuilder_TypeConstraint(
        builder, "T",
        static_cast<TF_DataType>(tensorflow::DataTypeToEnum<T>::v()), status);
    CHECK_EQ(TF_OK, TF_GetCode(status))
        << "Error while adding type constraint";
    TF_RegisterKernelBuilder("ScalarSummary", builder, status);
    CHECK_EQ(TF_OK, TF_GetCode(status))
        << "Error while registering Scalar Summary kernel";
  }
  TF_DeleteStatus(status);
}
// Static whose initializer runs a lambda for its side effect: registering
// the ScalarSummary kernel once for each supported value type. The stored
// value is always true and exists only to trigger the initializer.
TF_ATTRIBUTE_UNUSED bool IsScalarSummaryOpKernelRegistered = []() {
  const bool should_register = SHOULD_REGISTER_OP_KERNEL("ScalarSummary");
  if (!should_register) {
    return true;
  }
  // Integer types.
  RegisterScalarSummaryOpKernel<tensorflow::int64>();
  RegisterScalarSummaryOpKernel<tensorflow::uint64>();
  RegisterScalarSummaryOpKernel<tensorflow::int32>();
  RegisterScalarSummaryOpKernel<tensorflow::uint32>();
  RegisterScalarSummaryOpKernel<tensorflow::uint16>();
  RegisterScalarSummaryOpKernel<tensorflow::int16>();
  RegisterScalarSummaryOpKernel<tensorflow::int8>();
  RegisterScalarSummaryOpKernel<tensorflow::uint8>();
  // Floating-point types.
  RegisterScalarSummaryOpKernel<Eigen::half>();
  RegisterScalarSummaryOpKernel<tensorflow::bfloat16>();
  RegisterScalarSummaryOpKernel<float>();
  RegisterScalarSummaryOpKernel<double>();
  return true;
}();
} // namespace