blob: 123255bd3e2a8da001afddbb21aee0a21ac7979f [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <memory>
#include <vector>
#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_instance.pb.h" // NOLINT
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "third_party/tensorrt/NvInfer.h"
namespace tensorflow {
namespace tensorrt {
using ::nvinfer1::IRuntime;
class CreateTRTResourceHandle : public OpKernel {
public:
explicit CreateTRTResourceHandle(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("resource_name", &resource_name_));
}
void Compute(OpKernelContext* ctx) override {
{
mutex_lock l(mutex_);
if (!initialized_) {
AllocatorAttributes attr;
attr.set_on_host(true);
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
&handle_, attr));
VLOG(1) << "Creating TRT engine cache resource handle for op "
<< resource_name_ << " on device " << ctx->device()->name();
handle_.scalar<ResourceHandle>()() =
MakeResourceHandle<TRTEngineCacheResource>(
ctx, std::string(kTfTrtContainerName), resource_name_);
initialized_ = true;
}
}
ctx->set_output(0, handle_);
}
private:
string resource_name_;
Tensor handle_;
mutex mutex_;
bool initialized_ TF_GUARDED_BY(mutex_) = false;
TF_DISALLOW_COPY_AND_ASSIGN(CreateTRTResourceHandle);
};
REGISTER_KERNEL_BUILDER(Name("CreateTRTResourceHandle")
.Device(DEVICE_GPU)
.HostMemory("resource_handle"),
CreateTRTResourceHandle);
class InitializeTRTResource : public OpKernel {
public:
explicit InitializeTRTResource(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(
ctx, ctx->GetAttr("max_cached_engines_count", &max_cached_engines_));
}
void Compute(OpKernelContext* ctx) override {
ResourceHandle handle = HandleFromInput(ctx, 0);
core::RefCountPtr<TRTEngineCacheResource> resource;
OP_REQUIRES_OK(
ctx, LookupOrCreateResource<TRTEngineCacheResource>(
ctx, handle, &resource,
[this, ctx](TRTEngineCacheResource** resource) -> Status {
*resource = new TRTEngineCacheResource(
ctx, this->max_cached_engines_);
return Status::OK();
}));
auto allocator = resource->allocator_.get();
OP_REQUIRES(ctx, allocator != nullptr,
errors::Internal("Not able to initialize TRT engine cache when "
"GPU allocator is empty."));
OP_REQUIRES(ctx, resource->cache_.size() == 0,
errors::Internal("Expect engine cache to be empty, but got ",
resource->cache_.size(), " entries."));
// Get the file name.
const string& filename = ctx->input(1).scalar<tstring>()();
OP_REQUIRES(ctx, !filename.empty(),
errors::InvalidArgument("filename cannot be empty."));
// Parse the serialized engines and add them to the cache.
std::unique_ptr<RandomAccessFile> file;
OP_REQUIRES_OK(ctx, ctx->env()->NewRandomAccessFile(filename, &file));
auto reader = absl::make_unique<io::RecordReader>(file.get());
uint64 offset = 0;
int num_loaded_engine = 0;
do {
tstring record;
Status status = reader->ReadRecord(&offset, &record);
if (errors::IsOutOfRange(status)) break;
TRTEngineInstance engine_instance;
engine_instance.ParseFromString(record);
std::vector<TensorShape> engine_input_shapes;
const auto& input_shapes = engine_instance.input_shapes();
engine_input_shapes.reserve(input_shapes.size());
for (const TensorShapeProto& shape : input_shapes) {
engine_input_shapes.emplace_back(shape);
}
TrtUniquePtrType<IRuntime> infer(
nvinfer1::createInferRuntime(TRTEngineCacheResource::GetLogger()));
infer->setGpuAllocator(allocator);
TrtUniquePtrType<nvinfer1::ICudaEngine> engine(
infer->deserializeCudaEngine(
engine_instance.serialized_engine().c_str(),
engine_instance.serialized_engine().size(), nullptr));
auto raw_engine = engine.get();
std::vector<ExecutionContext> ctx_vec;
if (num_loaded_engine == 0) {
// Restore profiles if there are any. Currently only 1 engine is allowed
// in dynamic mode therefore we call this only for the 0th engine.
// it is a no-op in implicit batch mode.
OP_REQUIRES_OK(ctx, resource->profiles_.RestoreProfiles(
raw_engine, engine_input_shapes.size()));
OP_REQUIRES_OK(ctx, resource->profiles_.CreateExecutionContexts(
raw_engine, &ctx_vec));
} else {
// Multiple engines are only available in static mode. For each engine
// we have only a single execution context.
ctx_vec.push_back(ExecutionContext::Create(raw_engine));
}
resource->cache_.emplace(engine_input_shapes,
absl::make_unique<EngineContext>(
std::move(engine), std::move(ctx_vec)));
++num_loaded_engine;
} while (1);
VLOG(1) << "Loaded " << num_loaded_engine << " TRT engines for op "
<< handle.name() << " on device " << ctx->device()->name()
<< " from file " << filename;
}
private:
// Maximum number of cached engines
int max_cached_engines_;
TF_DISALLOW_COPY_AND_ASSIGN(InitializeTRTResource);
};
REGISTER_KERNEL_BUILDER(Name("InitializeTRTResource")
.Device(DEVICE_GPU)
.HostMemory("resource_handle"),
InitializeTRTResource);
class SerializeTRTResource : public OpKernel {
public:
explicit SerializeTRTResource(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("delete_resource", &delete_resource_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("save_gpu_specific_engines",
&save_gpu_specific_engines_));
}
void Compute(OpKernelContext* ctx) override {
const string& resource_name = ctx->input(0).scalar<tstring>()();
const string& filename = ctx->input(1).scalar<tstring>()();
OP_REQUIRES(ctx, !filename.empty(),
errors::InvalidArgument("filename cannot be empty."));
// Lookup engine cache resource.
TRTEngineCacheResource* resource = nullptr;
OP_REQUIRES(
ctx,
ctx->resource_manager()
->Lookup(std::string(kTfTrtContainerName), resource_name, &resource)
.ok(),
errors::NotFound("TRTEngineCacheResource not yet created"));
core::ScopedUnref unref_me(resource);
// Terminate the calibration if any.
if (resource->calib_ctx_) resource->calib_ctx_->TerminateCalibration();
// Serialize the engines and write them to file.
std::unique_ptr<WritableFile> file;
OP_REQUIRES_OK(ctx, ctx->env()->NewWritableFile(filename, &file));
auto writer = absl::make_unique<io::RecordWriter>(file.get());
int num_serialized_engines = 0;
if (save_gpu_specific_engines_) {
for (const auto& pair : resource->cache_) {
// Ignore engines that failed to build.
const std::unique_ptr<EngineContext>& engine = pair.second;
if (!engine || !engine->cuda_engine) continue;
TRTEngineInstance engine_instance;
// Add input shapes.
const std::vector<TensorShape>& engine_input_shapes = pair.first;
for (const TensorShape& shape : engine_input_shapes) {
shape.AsProto(engine_instance.add_input_shapes());
}
// Add the serialized engine.
TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(
engine->cuda_engine->serialize());
engine_instance.set_serialized_engine(engine_data->data(),
engine_data->size());
OP_REQUIRES_OK(
ctx, writer->WriteRecord(engine_instance.SerializeAsString()));
++num_serialized_engines;
}
} else {
VLOG(1) << "TRT Engines are not serialized for op: " << resource_name;
}
VLOG(1) << "Serialized " << num_serialized_engines << " TRT engines for op "
<< resource_name << " on device " << ctx->device()->name()
<< " to file " << filename;
if (delete_resource_) {
VLOG(1) << "Destroying TRT engine cache resource for op " << resource_name
<< " on device " << ctx->device()->name();
OP_REQUIRES_OK(ctx,
ctx->resource_manager()->Delete<TRTEngineCacheResource>(
std::string(kTfTrtContainerName), resource_name));
}
}
private:
bool delete_resource_ = false;
bool save_gpu_specific_engines_ = true;
TF_DISALLOW_COPY_AND_ASSIGN(SerializeTRTResource);
};
REGISTER_KERNEL_BUILDER(Name("SerializeTRTResource").Device(DEVICE_GPU),
SerializeTRTResource);
} // namespace tensorrt
} // namespace tensorflow
#endif // GOOGLE_CUDA && GOOGLE_TENSORRT