/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h"
#include <algorithm>
#include <memory>
#include <utility>
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
#include "tensorflow/contrib/tensorrt/convert/utils.h"
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
#include "tensorflow/contrib/tensorrt/test/utils.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "cuda/include/cuda_runtime_api.h"
namespace tensorflow {
namespace tensorrt {
static Logger logger;
using ::nvinfer1::IRuntime;
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
// A helper class to call done() when destructed for asynchronous execution.
// Helps simultaneous execution of native and TRT engines.
class AsyncHelper : public tensorflow::core::RefCounted {
public:
  explicit AsyncHelper(AsyncOpKernel::DoneCallback done)
      : done_(std::move(done)) {}
~AsyncHelper() override { done_(); }
private:
AsyncOpKernel::DoneCallback done_;
};
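// AsyncHelper usage: ComputeAsync() creates a single AsyncHelper and releases
// its own reference via ScopedUnref before returning; any code path that
// outlives ComputeAsync() (e.g. the native segment callback) takes an extra
// Ref() and unrefs when it finishes, so done() runs exactly once after the
// last reference is dropped.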
#define TYPECASE(dt, X)                                                   \
  case dt: {                                                              \
    return (void*)X->flat<tensorflow::EnumToDataType<dt>::Type>().data(); \
  }
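// Returns a type-erased pointer to the tensor's underlying data buffer, or
// nullptr if the tensor's dtype is not supported. For example,
// TYPECASE(tensorflow::DT_FLOAT, tensor_ptr) effectively expands to:
//   case tensorflow::DT_FLOAT: {
//     return (void*)tensor_ptr->flat<float>().data();
//   }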
void* GetTensorAddress(const Tensor* tensor_ptr) {
auto tensor_type = tensor_ptr->dtype();
switch (tensor_type) {
    TYPECASE(tensorflow::DT_FLOAT, tensor_ptr);
    TYPECASE(tensorflow::DT_HALF, tensor_ptr);
    TYPECASE(tensorflow::DT_INT8, tensor_ptr);
default: {
LOG(ERROR) << "Unsupported Data type "
<< tensorflow::DataTypeString(tensor_type);
return nullptr;
}
}
}
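// Instantiates the native FunctionDef that contains the original segment in
// the context's function library runtime and stores the resulting handle in
// native_func_, so the original TensorFlow subgraph can be executed as a
// fallback.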
tensorflow::Status TRTEngineOp::ConstructFunctionHandle(OpKernelContext* ctx) {
VLOG(1) << "Constructing function handle";
auto lib = ctx->function_library();
if (lib == nullptr) {
return tensorflow::errors::Internal("Context function library is null");
}
auto fdef = lib->GetFunctionLibraryDefinition()->Find(funcdef_name_);
if (fdef == nullptr) {
return tensorflow::errors::Internal("Native FunctionDef ", funcdef_name_,
" can't be found in function library");
}
tensorflow::FunctionLibraryRuntime::InstantiateOptions inst_ops;
inst_ops.overlay_lib = nullptr;
inst_ops.state_handle = "";
inst_ops.target = ctx->device()->name();
  native_func_ = tensorflow::kInvalidHandle;
auto status = lib->Instantiate(funcdef_name_, AttrSlice(&fdef->attr()),
inst_ops, &native_func_);
if (!status.ok()) {
LOG(ERROR) << " Instantiating native function " << funcdef_name_
<< " failed!";
}
return status;
}
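// Reads the op attributes: the serialized segment (a serialized engine or a
// GraphDef), workspace size, precision mode, optional calibration data, the
// name of the native fallback function, and the cached-engine limits.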
TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
: AsyncOpKernel(context) {
  // Read the serialized segment; for a static engine this is a serialized
  // TRT engine, otherwise it is the serialized GraphDef of the segment.
OP_REQUIRES_OK(context,
context->GetAttr("serialized_segment", &serialized_segment_));
OP_REQUIRES_OK(context,
context->GetAttr("workspace_size_bytes", &workspace_size_));
OP_REQUIRES_OK(context, context->GetAttr("static_engine", &static_engine_));
if (!static_engine_) {
if (!segment_graph_.ParseFromString(serialized_segment_)) {
LOG(ERROR) << "Parsing segment graph failed!";
context->SetStatus(tensorflow::errors::InvalidArgument(
"Failed to parse segment graphdef!"));
return;
}
serialized_segment_.resize(0);
}
VLOG(1) << "Constructing " << name();
string precision_string;
OP_REQUIRES_OK(context,
context->GetAttr("precision_mode", &precision_string));
string calibration_data;
OP_REQUIRES_OK(context,
context->GetAttr("calibration_data", &calibration_data));
OP_REQUIRES_OK(context,
context->GetAttr("segment_funcdef_name", &funcdef_name_));
OP_REQUIRES_OK(context, GetPrecisionMode(precision_string, &precision_mode_));
  calibration_mode_ =
      (precision_mode_ == INT8MODE && calibration_data.empty());
  if (!calibration_data.empty()) {
    calibrator_.reset(new TRTInt8Calibrator(calibration_data));
    calibration_data.resize(0);
  }
native_func_ = tensorflow::kInvalidHandle;
OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count",
&max_cached_engines_));
OP_REQUIRES_OK(context,
context->GetAttr("fixed_input_size", &fixed_input_size_));
OP_REQUIRES_OK(context, context->GetAttr("cached_engine_batches",
&cached_engine_batches_));
std::sort(cached_engine_batches_.begin(), cached_engine_batches_.end());
if (VLOG_IS_ON(1)) {
string s("Engine Batches= ");
for (auto i : cached_engine_batches_) {
StrAppend(&s, i, " ");
}
VLOG(1) << s;
}
}
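// Asynchronously runs the original TensorFlow subgraph via the function
// library runtime. Used as a fallback when a TRT engine is unavailable or
// fails, and to produce the op's outputs while calibration data is being
// collected.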
void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
AsyncHelper* helper) {
if (!calibration_mode_) {
VLOG(1) << "Executing native engine";
}
  if (native_func_ == tensorflow::kInvalidHandle) {
    auto status = ConstructFunctionHandle(ctx);
    if (!status.ok()) {
      LOG(ERROR) << "Couldn't construct function handle " << funcdef_name_;
      ctx->SetStatus(status);
      return;
    }
  }
  std::vector<Tensor> inputs;
  // Allocated on the heap so it outlives this call; the asynchronous callback
  // below takes ownership and releases it.
  std::vector<Tensor>* outputs = new std::vector<Tensor>();
auto lib = ctx->function_library();
tensorflow::FunctionLibraryRuntime::Options opts;
opts.step_id = ctx->step_id();
opts.rendezvous = ctx->rendezvous();
opts.cancellation_manager = ctx->cancellation_manager();
opts.runner = ctx->runner();
for (int i = 0; i < ctx->num_inputs(); i++) {
inputs.push_back(ctx->input(i));
}
  helper->Ref();  // Keep the helper alive until the callback below runs.
VLOG(1) << "Executing native segment " << name();
lib->Run(opts, native_func_, inputs, outputs,
           [this, ctx, outputs, helper](const tensorflow::Status& s) {
             tensorflow::core::ScopedUnref sc(helper);
             // Take ownership of the heap-allocated outputs so they are
             // released on every exit path, not only on success.
             std::unique_ptr<std::vector<Tensor>> outputs_owner(outputs);
             VLOG(1) << "Native Segment completed";
             if (!s.ok()) {
               ctx->SetStatus(s);
               return;
             }
             for (size_t t = 0; t < outputs->size(); ++t) {
               ctx->set_output(t, outputs->at(t));
             }
             test::AddTestValue(StrCat(this->name(), ":ExecuteNativeSegment"),
                                "done");
           });
}
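// Feeds the current input batch to the INT8 calibrator (consumed by the
// calibration thread created in AllocateCalibrationResources()) and then runs
// the native segment so the op still produces outputs during calibration.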
void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
AsyncHelper* helper) {
helper->Ref();
tensorflow::core::ScopedUnref sc(helper);
// TODO(aaroey): remove the ResourceMgr singleton.
auto trt_rm = TRTResourceManager::instance();
auto res_mgr = trt_rm->getManager("TRTCalibration");
TRTCalibrationResource* calib_res = nullptr;
auto status = res_mgr->LookupOrCreate(
funcdef_name_, "Calibrator", &calib_res,
{[ctx, this](TRTCalibrationResource** cr) -> tensorflow::Status {
return this->AllocateCalibrationResources(ctx, cr);
}});
if (!status.ok()) {
ctx->SetStatus(status);
return;
}
int num_inputs = ctx->num_inputs();
// Pass input data to calibrator
std::unordered_map<string, void*> input_data;
for (int i = 0; i < num_inputs; i++) {
const Tensor& t = ctx->input(i);
void* data_address = GetTensorAddress(&t);
if (data_address == nullptr) {
ctx->SetStatus(tensorflow::errors::InvalidArgument(
"Unsupported data type encountered in input ", i));
return;
}
    // Check that the persistent buffer allocated for this input matches the
    // input's size.
const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx);
CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes());
input_data.emplace(StrCat(kInputPHName, i), data_address);
}
VLOG(2) << "Filled map for sending";
  // Copied from cuda_kernel_helper since it seems only valid in *.cu.cc files.
const cudaStream_t* stream = CHECK_NOTNULL(
reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
->stream()
->implementation()
->GpuStreamMemberHack()));
calib_res->calibrator_->setBatch(input_data, *stream);
test::AddTestValue(StrCat(name(), ":ExecuteCalibration"), "done");
VLOG(2) << "Passed calibration data";
ExecuteNativeSegment(ctx, helper);
}
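// Returns the smallest cached engine batch size that can serve the current
// input batch (cached_engine_batches_ is kept sorted). If none fits and the
// cache still has room, the input batch size is added as a new entry; returns
// -1 when the cache is full and no cached batch size is large enough.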
int TRTEngineOp::GetEngineBatch(OpKernelContext* ctx) {
int num_batch = ctx->input(0).shape().dim_size(0);
int smallest_engine = 0;
for (const auto i : cached_engine_batches_) {
if (i >= num_batch) {
smallest_engine = i;
break;
}
}
// TODO(sami): Need an LRU here
if (smallest_engine == 0) {
if (max_cached_engines_ > cached_engine_batches_.size()) {
smallest_engine = num_batch;
cached_engine_batches_.push_back(num_batch);
VLOG(1) << "Running with batch size " << num_batch;
} else {
string msg =
StrCat("Engine buffer is full. buffer limit=", max_cached_engines_,
", current entries=");
for (auto i : cached_engine_batches_) StrAppend(&msg, i, ",");
StrAppend(&msg, " requested batch=", num_batch);
LOG(WARNING) << msg;
return -1;
}
}
return smallest_engine;
}
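// Entry point for asynchronous execution: dispatches to calibration, TRT
// engine execution, or the native segment, and falls back to the native
// segment whenever a suitable engine cannot be obtained or fails to run.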
void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
AsyncOpKernel::DoneCallback done) {
auto helper = new AsyncHelper(done);
tensorflow::core::ScopedUnref sc(helper);
if (calibration_mode_) {
ExecuteCalibration(ctx, helper);
return;
}
const int smallest_engine = GetEngineBatch(ctx);
if (smallest_engine < 0) {
LOG(WARNING) << "Failed to get engine batch, running native segment for "
<< name();
ExecuteNativeSegment(ctx, helper);
return;
}
const int num_batch = ctx->input(0).shape().dim_size(0);
auto& engine_ctx_pair = GetEngine(smallest_engine, ctx);
auto& trt_engine_ptr = engine_ctx_pair.first;
if (!trt_engine_ptr) {
LOG(WARNING) << "Engine retrieval for batch size " << num_batch
<< " failed. Running native segment for " << name();
ExecuteNativeSegment(ctx, helper);
return;
}
const bool retry = ExecuteTrtEngine(ctx, num_batch, trt_engine_ptr.get(),
engine_ctx_pair.second.get());
if (retry) {
LOG(WARNING) << "Failed to execute engine, "
<< "retrying with native segment for " << name();
ExecuteNativeSegment(ctx, helper);
return;
}
}
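// Binds the op's input and output tensors to the engine's named bindings and
// enqueues execution on the CUDA stream owned by TensorFlow. Returns true
// (kRetry) when the caller should fall back to the native segment, and false
// when execution was enqueued or a retry is not possible (e.g. output
// allocation failed).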
bool TRTEngineOp::ExecuteTrtEngine(
OpKernelContext* ctx, const int num_batch,
nvinfer1::ICudaEngine* trt_engine_ptr,
nvinfer1::IExecutionContext* trt_execution_context_ptr) {
const bool kRetry = true;
const int num_binding = ctx->num_inputs() + ctx->num_outputs();
std::vector<void*> buffers(num_binding);
for (int i = 0; i < ctx->num_inputs(); i++) {
const string input_name = StrCat(kInputPHName, i);
    // getBindingIndex() returns -1 when the name is not found, so keep the
    // index signed.
    const int binding_index =
        trt_engine_ptr->getBindingIndex(input_name.c_str());
    if (binding_index == -1) {
      LOG(ERROR) << "Input node not found: " << input_name;
return kRetry;
}
const Tensor& input_tensor = ctx->input(i);
const TensorShape& input_shape = input_tensor.shape();
if (num_batch != input_shape.dim_size(0)) {
LOG(ERROR) << "Input data has inconsistent batch size: " << num_batch
<< " vs " << input_shape.dim_size(0);
return kRetry;
}
auto dtype = trt_engine_ptr->getBindingDataType(binding_index);
switch (dtype) {
case nvinfer1::DataType::kFLOAT:
buffers[binding_index] = (void*)(input_tensor.flat<float>().data());
break;
case nvinfer1::DataType::kHALF:
LOG(ERROR) << "FP16 inputs are not supported yet!";
return kRetry;
case nvinfer1::DataType::kINT8:
LOG(ERROR) << "INT8 inputs are not supported yet!";
return kRetry;
#if NV_TENSORRT_MAJOR > 3
case nvinfer1::DataType::kINT32:
buffers[binding_index] = (void*)(input_tensor.flat<int32>().data());
break;
#endif
default:
LOG(ERROR) << "Unknown TRT data type: " << int(dtype);
return kRetry;
}
}
for (int i = 0; i < ctx->num_outputs(); i++) {
// Create an output tensor
const string output_name = StrCat(kOutputPHName, i);
    // Keep the index signed: getBindingIndex() returns -1 when not found.
    const int binding_index =
        trt_engine_ptr->getBindingIndex(output_name.c_str());
Tensor* output_tensor = nullptr;
TensorShape output_shape;
if (binding_index != -1) {
auto dims = trt_engine_ptr->getBindingDimensions(binding_index);
std::vector<int> trt_shape(dims.nbDims + 1);
trt_shape[0] = num_batch;
for (int j = 0; j < dims.nbDims; j++) trt_shape[j + 1] = dims.d[j];
auto status = TensorShapeUtils::MakeShape(
trt_shape.data(), trt_shape.size(), &output_shape);
if (!status.ok()) {
LOG(ERROR) << "Failed to get output shape: " << status;
return kRetry;
}
} else {
LOG(ERROR) << "Output node not found, at " << output_name;
return kRetry;
}
auto status = ctx->allocate_output(i, output_shape, &output_tensor);
if (!status.ok()) {
LOG(ERROR) << "Allocating output failed with " << status;
ctx->SetStatus(status);
// Do not retry since we cannot allocate the same output twice.
// TODO(aaroey): ideally we should retry, fix this.
return !kRetry;
}
auto dtype = trt_engine_ptr->getBindingDataType(binding_index);
switch (dtype) {
case nvinfer1::DataType::kFLOAT:
buffers[binding_index] =
reinterpret_cast<void*>(output_tensor->flat<float>().data());
break;
case nvinfer1::DataType::kHALF:
LOG(WARNING) << "half size is not supported yet!";
return kRetry;
case nvinfer1::DataType::kINT8:
LOG(WARNING) << "int8 is not supported yet!";
return kRetry;
#if NV_TENSORRT_MAJOR > 3
case nvinfer1::DataType::kINT32:
buffers[binding_index] =
reinterpret_cast<void*>(output_tensor->flat<int32>().data());
break;
#endif
default:
LOG(WARNING) << "Unknown TRT data type: " << static_cast<int>(dtype);
return kRetry;
}
}
// Copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
const cudaStream_t* stream = CHECK_NOTNULL(
reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
->stream()
->implementation()
->GpuStreamMemberHack()));
// TODO(jie): trt enqueue does not return error
auto ret = trt_execution_context_ptr->enqueue(num_batch, &buffers[0], *stream,
nullptr);
if (!ret) {
LOG(WARNING) << "Failed to enqueue batch for TRT engine: " << name();
return kRetry;
}
test::AddTestValue(StrCat(name(), ":ExecuteTrtEngine"), "done");
// Synchronization will be done by TF.
return !kRetry;
}
TRTEngineOp::~TRTEngineOp() {
// We need to manually destroy the engine and execution context before
// the allocator is destructed.
for (auto& eng : engine_map_) {
eng.second.first.reset();
eng.second.second.reset();
}
allocator_.reset();
}
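// Lazily creates and caches a TRT GPU allocator that forwards allocations to
// the device's TensorFlow allocator; returns nullptr if the device allocator
// cannot be found.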
nvinfer1::IGpuAllocator* TRTEngineOp::GetAllocator(OpKernelContext* ctx) {
if (allocator_) return allocator_.get();
auto device = ctx->device();
auto alloc = device->GetAllocator(tensorflow::AllocatorAttributes());
if (!alloc) {
LOG(ERROR) << "Can't find device allocator for gpu device "
<< device->name();
return nullptr;
}
allocator_.reset(new TRTDeviceAllocator(alloc));
return allocator_.get();
}
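// Returns the engine/execution-context pair to use for the given batch size.
// For a static engine the serialized engine is deserialized on first use; in
// dynamic mode a new engine is built and cached if the cache has room.
// Returns a pair of null pointers when no usable engine is available, in
// which case the caller falls back to the native segment.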
TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size,
OpKernelContext* ctx) {
static EngineCtxPair null_pair = {
TrtUniquePtrType<nvinfer1::ICudaEngine>(nullptr),
TrtUniquePtrType<nvinfer1::IExecutionContext>(nullptr)};
// TODO(sami): This method needs to be re-written to use resource manager and
// with LRU mechanism option.
tensorflow::mutex_lock lock(engine_mutex_);
if (static_engine_) {
    if (!engine_map_.empty()) {
if (engine_map_.begin()->first >= batch_size) {
return engine_map_.begin()->second;
}
return null_pair;
}
TrtUniquePtrType<IRuntime> infer(nvinfer1::createInferRuntime(logger));
#if NV_TENSORRT_MAJOR > 3
auto allocator = GetAllocator(ctx);
if (allocator == nullptr) {
return null_pair;
}
infer->setGpuAllocator(allocator);
#endif
TrtUniquePtrType<nvinfer1::ICudaEngine> static_engine(
infer->deserializeCudaEngine(serialized_segment_.c_str(),
serialized_segment_.size(),
PluginFactoryTensorRT::GetInstance()));
auto raw_static_engine = static_engine.get();
const auto max_batch_size = raw_static_engine->getMaxBatchSize();
engine_map_[max_batch_size] = {
std::move(static_engine),
TrtUniquePtrType<nvinfer1::IExecutionContext>(
raw_static_engine->createExecutionContext())};
    // The runtime is safe to destroy after engine creation (it goes out of
    // scope at the end of this block), and the serialized segment is no
    // longer needed once the engine has been deserialized.
    serialized_segment_.clear();
if (max_batch_size < batch_size) {
return null_pair;
}
return engine_map_.at(max_batch_size);
} // static_engine_
// Handle the dynamic engine case.
auto engine_it = engine_map_.find(batch_size);
if (engine_it == engine_map_.end() &&
engine_map_.size() < (size_t)max_cached_engines_) {
nvinfer1::IGpuAllocator* allocator = nullptr;
#if NV_TENSORRT_MAJOR > 3
allocator = GetAllocator(ctx);
if (allocator == nullptr) {
return null_pair;
}
#endif
std::vector<tensorflow::PartialTensorShape> shapes;
for (int i = 0; i < ctx->num_inputs(); ++i) {
shapes.emplace_back(ctx->input(i).shape());
}
TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
bool convert_successfully = false;
VLOG(0) << name() << " Constructing a new engine with batch size "
<< batch_size;
    // Note: if precision_mode_ is INT8, calibrator_ is guaranteed to be
    // non-null here; an empty calibrator would imply calibration_mode_ is
    // true, in which case this path is never executed.
auto status = convert::ConvertGraphDefToEngine(
segment_graph_, precision_mode_, batch_size, workspace_size_, shapes,
&logger, allocator, calibrator_.get(), &engine, &convert_successfully);
if (!status.ok()) {
if (convert_successfully) {
        // This means the engine build failed even though the network was
        // converted successfully, probably due to an internal error. In this
        // case we cache null pointers so that we don't retry in the future.
engine_map_[batch_size] = {nullptr, nullptr};
}
LOG(WARNING) << "Engine creation for batch size " << batch_size
<< " failed " << status;
return null_pair;
}
VLOG(1) << "Conversion is done";
TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
engine->createExecutionContext());
engine_map_[batch_size] = {std::move(engine), std::move(exec_context)};
}
return engine_map_.at(batch_size);
}
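// Creates the calibration resource for this op: persistent device buffers
// mirroring the op's inputs, a TRTInt8Calibrator fed from those buffers, and
// a background thread that builds the INT8 engine while consuming the
// calibration batches pushed by ExecuteCalibration().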
tensorflow::Status TRTEngineOp::AllocateCalibrationResources(
OpKernelContext* ctx, TRTCalibrationResource** cr) {
auto cres = new TRTCalibrationResource();
*cr = cres;
// Get the allocator.
auto alloc = ctx->device()->GetAllocator(tensorflow::AllocatorAttributes());
if (!alloc) {
LOG(WARNING) << "Can't get device allocator will not be able to "
"allocate memory from TensorFlow memory pool";
cres->allocator_.reset(new TRTCudaAllocator);
} else {
cres->allocator_.reset(new TRTDeviceAllocator(alloc));
}
// Get the input shapes.
const int batch_size = ctx->input(0).dim_size(0);
const int num_inputs = ctx->num_inputs();
std::vector<tensorflow::PartialTensorShape> shapes;
dev_tensors_.resize(num_inputs);
VLOG(1) << " Constructing calibrator";
for (int i = 0; i < num_inputs; i++) {
    // Allocate a persistent device buffer for each input to hold the
    // calibration data.
const tensorflow::Tensor& t = ctx->input(i);
shapes.emplace_back(t.shape());
Tensor* device_tensor;
TF_RETURN_IF_ERROR(ctx->allocate_persistent(
t.dtype(), t.shape(), &dev_tensors_.at(i), &device_tensor));
CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes());
void* device_address = GetTensorAddress(device_tensor);
if (device_address == nullptr) {
return tensorflow::errors::InvalidArgument(
"Unsupported data type encountered in input ", i);
}
device_buffers_.emplace(
StrCat(kInputPHName, i),
std::pair<void*, size_t>(device_address, device_tensor->TotalBytes()));
}
cres->calibrator_.reset(
new TRTInt8Calibrator(device_buffers_, batch_size, name()));
const string label(name());
auto segment_graph = &segment_graph_;
  // Check the device info before dereferencing it, so a missing GPU device
  // yields an error instead of a crash.
  const auto* gpu_device_info = ctx->device()->tensorflow_gpu_device_info();
  if (gpu_device_info == nullptr || gpu_device_info->gpu_id < 0) {
    LOG(ERROR) << "Can't get gpu_device_info from context->device()";
    return tensorflow::errors::InvalidArgument(
        "Context->device doesn't contain device info!");
  }
  const int cuda_gpu_id = gpu_device_info->gpu_id;
const int64 workspace_size_bytes = workspace_size_;
cres->thr_.reset(new std::thread([cres, label, segment_graph, shapes,
cuda_gpu_id, workspace_size_bytes]() {
VLOG(0) << "Starting calibration thread on device " << cuda_gpu_id
<< ", Calibration Resource @ " << cres;
auto err = cudaSetDevice(cuda_gpu_id);
if (err != cudaSuccess) {
// TODO(aaroey): should return error here.
LOG(ERROR) << "Couldn't set cuda device to " << cuda_gpu_id
<< " in calibration thread";
}
    // ConvertGraphDefToEngine() will try to build the engine. This thread
    // will loop inside buildCudaEngine(), consuming the calibration data
    // that is set by the TF op, and drive the builder until the calibrator
    // returns false. The engine is discarded after the calibration table is
    // generated.
    //
    // TODO(aaroey): maybe set the max batch size using the python
    // calibration wrapper class.
auto s = convert::ConvertGraphDefToEngine(
*segment_graph, INT8MODE, cres->calibrator_->getBatchSize(),
workspace_size_bytes, shapes, &cres->logger_, cres->allocator_.get(),
cres->calibrator_.get(), &cres->engine_,
/*convert_successfully=*/nullptr);
if (!s.ok()) {
LOG(ERROR) << "Calibration failed: " << s;
cres->calibrator_->setDone(); // Ignore further pushes
}
VLOG(1) << "Calibration loop terminated " << label;
}));
VLOG(1) << "initialized calibrator resource";
return tensorflow::Status::OK();
}
REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp);
} // namespace tensorrt
} // namespace tensorflow
#endif // GOOGLE_TENSORRT
#endif // GOOGLE_CUDA