| /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| #define EIGEN_USE_THREADS |
| |
| #include <stddef.h> |
| #include <atomic> |
| #include <cmath> |
| #include <functional> |
| #include <limits> |
| #include <string> |
| #include <unordered_set> |
| |
| #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
| #include "tensorflow/core/framework/device_base.h" |
| #include "tensorflow/core/framework/kernel_def_builder.h" |
| #include "tensorflow/core/framework/op.h" |
| #include "tensorflow/core/framework/op_def_builder.h" |
| #include "tensorflow/core/framework/op_kernel.h" |
| #include "tensorflow/core/framework/register_types.h" |
| #include "tensorflow/core/framework/tensor.h" |
| #include "tensorflow/core/framework/tensor_shape.h" |
| #include "tensorflow/core/framework/tensor_types.h" |
| #include "tensorflow/core/framework/types.h" |
| #include "tensorflow/core/kernels/gpu_utils.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/core/status.h" |
| #include "tensorflow/core/lib/core/stringpiece.h" |
| #include "tensorflow/core/lib/gtl/inlined_vector.h" |
| #include "tensorflow/core/lib/hash/hash.h" |
| #include "tensorflow/core/lib/strings/stringprintf.h" |
| #include "tensorflow/core/platform/fingerprint.h" |
| #include "tensorflow/core/platform/mutex.h" |
| #include "tensorflow/core/platform/types.h" |
| #include "tensorflow/core/util/env_var.h" |
| #include "tensorflow/core/util/use_cudnn.h" |
| |
| #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
| #include "tensorflow/core/platform/stream_executor.h" |
| #include "tensorflow/core/util/stream_executor_util.h" |
| #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
| |
| /* |
| * This module implements ops that fuse a multi-layer multi-step RNN/LSTM model |
| * using the underlying Cudnn library. |
| * |
| * The Cudnn RNN library exposes an opaque parameter buffer with an unknown
| * layout and format, and a buffer saved on one GPU is very likely unusable on
| * a different GPU. Users therefore need to first query the size of the opaque
| * parameter buffer and convert it to and from its canonical forms, while each
| * actual training step is carried out with the opaque parameter buffer.
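| *
| * A rough sketch of how the kernel classes in this file fit together (for
| * illustration only):
| *   1. CudnnRNNParamsSizeOp      - queries the size of the opaque buffer.
| *   2. CudnnRNNCanonicalToParams - packs canonical weights/biases into it.
| *   3. CudnnRNNForwardOp (V2/V3) - runs the fused forward pass; in training
| *      mode it also produces reserve_space for the backward pass.
| *   4. CudnnRNNBackwardOp (V2/V3) - consumes reserve_space to compute
| *      gradients.
| *   5. CudnnRNNParamsToCanonical - unpacks the opaque buffer, e.g. for
| *      checkpointing.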
| * |
| * Similar to many other ops, the forward op has two flavors: training and
| * inference. When training is specified, additional data is produced in
| * reserve_space for the backward pass, which incurs a performance penalty.
| * |
| * In addition to the actual data and reserve_space, Cudnn also needs extra
| * memory as a temporary workspace. Memory management to and from
| * stream-executor is done through ScratchAllocator: in general,
| * stream-executor is responsible for allocating memory of the proper size,
| * and TensorFlow is responsible for keeping the memory alive long enough and
| * recycling it afterwards.
| * |
| */ |
| namespace tensorflow { |
| |
| using CPUDevice = Eigen::ThreadPoolDevice; |
| |
| #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
| |
| using GPUDevice = Eigen::GpuDevice; |
| using se::Stream; |
| using se::StreamExecutor; |
| using se::dnn::RnnDescriptor; |
| |
| template <typename Device, typename T, typename Index> |
| class CudnnRNNParamsSizeOp; |
| |
| template <typename Device, typename T> |
| class CudnnRNNParamsToCanonical; |
| |
| template <typename Device, typename T> |
| class CudnnRNNCanonicalToParams; |
| |
| template <typename Device, typename T> |
| class CudnnRNNForwardOp; |
| |
| template <typename Device, typename T> |
| class CudnnRNNBackwardOp; |
| |
| template <typename Device, typename T> |
| class CudnnRNNForwardOpV2; |
| |
| template <typename Device, typename T> |
| class CudnnRNNBackwardOpV2; |
| |
| template <typename Device, typename T> |
| class CudnnRNNForwardOpV3; |
| |
| template <typename Device, typename T> |
| class CudnnRNNBackwardOpV3; |
| |
| enum class TFRNNInputMode { |
| kRNNLinearInput = 0, |
| kRNNSkipInput = 1, |
| kAutoSelect = 9999999 |
| }; |
| |
| namespace { |
| using se::DeviceMemory; |
| using se::DeviceMemoryBase; |
| using se::ScratchAllocator; |
| using se::dnn::AlgorithmConfig; |
| using se::dnn::AlgorithmDesc; |
| using se::dnn::ProfileResult; |
| using se::dnn::RnnDirectionMode; |
| using se::dnn::RnnInputMode; |
| using se::dnn::RnnMode; |
| using se::dnn::RnnSequenceTensorDescriptor; |
| using se::dnn::RnnStateTensorDescriptor; |
| using se::dnn::ToDataType; |
| using se::port::StatusOr; |
| |
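| // Combines a list of ints into a single fingerprint using Hash64Combine.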
| uint64 HashList(const std::vector<int>& list) { |
| if (list.empty()) { |
| return 0; |
| } |
| uint64 hash_code = list[0]; |
| for (int i = 1; i < list.size(); i++) { |
| hash_code = Hash64Combine(hash_code, list[i]); |
| } |
| return hash_code; |
| } |
| |
| // Encapsulates all the shape information used in both the forward and
| // backward RNN operations.
| class CudnnRnnParameters { |
| public: |
| CudnnRnnParameters(int num_layers, int input_size, int num_units, |
| int max_seq_length, int batch_size, int dir_count, |
| bool has_dropout, bool is_training, RnnMode rnn_mode, |
| TFRNNInputMode rnn_input_mode, DataType dtype) |
| : num_layers_(num_layers), |
| input_size_(input_size), |
| num_units_(num_units), |
| seq_length_(max_seq_length), |
| batch_size_(batch_size), |
| dir_count_(dir_count), |
| has_dropout_(has_dropout), |
| is_training_(is_training), |
| rnn_mode_(rnn_mode), |
| rnn_input_mode_(rnn_input_mode), |
| dtype_(dtype) { |
| hash_code_ = |
| HashList({num_layers, input_size, num_units, max_seq_length, batch_size, |
| dir_count, static_cast<int>(has_dropout), |
| static_cast<int>(is_training), static_cast<int>(rnn_mode), |
| static_cast<int>(rnn_input_mode), dtype}); |
| } |
| |
| bool operator==(const CudnnRnnParameters& other) const { |
| return this->get_data_as_tuple() == other.get_data_as_tuple(); |
| } |
| |
| bool operator!=(const CudnnRnnParameters& other) const { |
| return !(*this == other); |
| } |
| uint64 hash() const { return hash_code_; } |
| |
| string ToString() const { |
| std::vector<string> fields = { |
| std::to_string(num_layers_), |
| std::to_string(input_size_), |
| std::to_string(num_units_), |
| std::to_string(seq_length_), |
| std::to_string(batch_size_), |
| std::to_string(dir_count_), |
| std::to_string(has_dropout_), |
| std::to_string(is_training_), |
| std::to_string(static_cast<int>(rnn_mode_)), |
| std::to_string(static_cast<int>(rnn_input_mode_)), |
| std::to_string(static_cast<int>(dtype_))}; |
| return absl::StrJoin(fields, ", "); |
| } |
| |
| private: |
| using ParameterDataType = std::tuple<int, int, int, int, int, int, bool, bool, |
| RnnMode, TFRNNInputMode, DataType>; |
| |
| ParameterDataType get_data_as_tuple() const { |
| return std::make_tuple(num_layers_, input_size_, num_units_, seq_length_, |
| batch_size_, dir_count_, has_dropout_, is_training_, |
| rnn_mode_, rnn_input_mode_, dtype_); |
| } |
| |
| const int num_layers_; |
| const int input_size_; |
| const int num_units_; |
| const int seq_length_; |
| const int batch_size_; |
| const int dir_count_; |
| const bool has_dropout_; |
| const bool is_training_; |
| const RnnMode rnn_mode_; |
| const TFRNNInputMode rnn_input_mode_; |
| const DataType dtype_; |
| uint64 hash_code_; |
| }; |
| |
| struct RnnAutoTuneGroup { |
| static string name() { return "Rnn"; } |
| }; |
| |
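| // Autotune singleton that maps a CudnnRnnParameters key to the best
| // AlgorithmConfig found so far for that configuration.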
| using AutoTuneRnnConfigMap = |
| AutoTuneSingleton<RnnAutoTuneGroup, CudnnRnnParameters, AlgorithmConfig>; |
| |
| Status ParseRNNMode(const string& str, RnnMode* rnn_mode) { |
| if (str == "rnn_relu") { |
| *rnn_mode = RnnMode::kRnnRelu; |
| return Status::OK(); |
| } else if (str == "rnn_tanh") { |
| *rnn_mode = RnnMode::kRnnTanh; |
| return Status::OK(); |
| } else if (str == "lstm") { |
| *rnn_mode = RnnMode::kRnnLstm; |
| return Status::OK(); |
| } else if (str == "gru") { |
| *rnn_mode = RnnMode::kRnnGru; |
| return Status::OK(); |
| } |
| return errors::InvalidArgument("Invalid RNN mode: ", str); |
| } |
| |
| Status ParseTFRNNInputMode(const string& str, TFRNNInputMode* rnn_input_mode) { |
| if (str == "linear_input") { |
| *rnn_input_mode = TFRNNInputMode::kRNNLinearInput; |
| return Status::OK(); |
| } else if (str == "skip_input") { |
| *rnn_input_mode = TFRNNInputMode::kRNNSkipInput; |
| return Status::OK(); |
| } else if (str == "auto_select") { |
| *rnn_input_mode = TFRNNInputMode::kAutoSelect; |
| return Status::OK(); |
| } |
| return errors::InvalidArgument("Invalid RNN input mode: ", str); |
| } |
| |
| Status ParseRNNDirectionMode(const string& str, |
| RnnDirectionMode* rnn_dir_mode) { |
| if (str == "unidirectional") { |
| *rnn_dir_mode = RnnDirectionMode::kRnnUnidirectional; |
| return Status::OK(); |
| } else if (str == "bidirectional") { |
| *rnn_dir_mode = RnnDirectionMode::kRnnBidirectional; |
| return Status::OK(); |
| } |
| return errors::InvalidArgument("Invalid RNN direction mode: ", str); |
| } |
| |
| Status ToRNNInputMode(TFRNNInputMode tf_input_mode, int num_units, |
| int input_size, RnnInputMode* input_mode) { |
| switch (tf_input_mode) { |
| case TFRNNInputMode::kRNNLinearInput: |
| *input_mode = RnnInputMode::kRnnLinearSkip; |
| break; |
| case TFRNNInputMode::kRNNSkipInput: |
| *input_mode = RnnInputMode::kRnnSkipInput; |
| break; |
| case TFRNNInputMode::kAutoSelect: |
| *input_mode = (input_size == num_units) ? RnnInputMode::kRnnSkipInput |
| : RnnInputMode::kRnnLinearSkip; |
| break; |
| default: |
| return errors::InvalidArgument("Invalid TF input mode: ", |
| static_cast<int>(tf_input_mode)); |
| } |
| return Status::OK(); |
| } |
| |
| // TODO(zhengxq): Merge those into stream_executor_util.h. |
| template <typename T> |
| const DeviceMemory<T> AsDeviceMemory(const Tensor* tensor) { |
| return DeviceMemory<T>::MakeFromByteSize( |
| const_cast<T*>(tensor->template flat<T>().data()), |
| tensor->template flat<T>().size() * sizeof(T)); |
| } |
| |
| template <typename T> |
| DeviceMemory<T> AsDeviceMemory(Tensor* tensor) { |
| return DeviceMemory<T>::MakeFromByteSize( |
| tensor->template flat<T>().data(), |
| tensor->template flat<T>().size() * sizeof(T)); |
| } |
| |
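| // Reinterprets a tensor with element type T as a DeviceMemory<U> covering the
| // same underlying bytes.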
| template <typename U, typename T> |
| DeviceMemory<U> CastDeviceMemory(Tensor* tensor) { |
| return DeviceMemory<U>::MakeFromByteSize( |
| tensor->template flat<T>().data(), |
| tensor->template flat<T>().size() * sizeof(T)); |
| } |
| |
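| // Returns the sub-region [offset, offset + size) of the given device memory.
| // The slice must lie within the original region.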
| DeviceMemoryBase SliceDeviceMemory(const DeviceMemoryBase& device_memory, |
| int64 offset, int64 size) { |
| const void* base_ptr = device_memory.opaque(); |
| void* offset_ptr = |
| const_cast<char*>(reinterpret_cast<const char*>(base_ptr) + offset); |
| CHECK(offset + size <= device_memory.size()) |
| << "The slice is not within the region of DeviceMemory."; |
| return DeviceMemoryBase(offset_ptr, size); |
| } |
| |
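| // Conversion helpers between StreamExecutor port::Status and TensorFlow
| // Status.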
| inline Status FromExecutorStatus(const se::port::Status& s) { |
| return s.ok() ? Status::OK() |
| : Status(static_cast<error::Code>(static_cast<int>(s.code())), |
| s.error_message()); |
| } |
| |
| template <typename T> |
| inline Status FromExecutorStatus(const se::port::StatusOr<T>& s) { |
| return FromExecutorStatus(s.status()); |
| } |
| |
| inline se::port::Status ToExecutorStatus(const Status& s) { |
| return s.ok() ? se::port::Status::OK() |
| : se::port::Status(static_cast<se::port::error::Code>( |
| static_cast<int>(s.code())), |
| s.error_message()); |
| } |
| |
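| // Compile-time mapping from a C++ element type to the corresponding TensorFlow
| // DataType (the TF-side counterpart of se::dnn::ToDataType).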
| template <typename> |
| struct ToTFDataType; |
| |
| template <> |
| struct ToTFDataType<Eigen::half> : std::integral_constant<DataType, DT_HALF> {}; |
| |
| template <> |
| struct ToTFDataType<float> : std::integral_constant<DataType, DT_FLOAT> {}; |
| |
| template <> |
| struct ToTFDataType<double> : std::integral_constant<DataType, DT_DOUBLE> {}; |
| |
| template <> |
| struct ToTFDataType<uint8> : std::integral_constant<DataType, DT_UINT8> {}; |
| |
| // A helper to allocate temporary scratch memory for Cudnn RNN models. It
| // takes ownership of the underlying memory, which is expected to stay alive
| // for the duration of the Cudnn RNN call itself.
| template <typename T> |
| class CudnnRnnAllocatorInTemp : public ScratchAllocator { |
| public: |
| ~CudnnRnnAllocatorInTemp() override = default; |
| |
| explicit CudnnRnnAllocatorInTemp(OpKernelContext* context) |
| : context_(context) {} |
| int64 GetMemoryLimitInBytes() override { |
| return std::numeric_limits<int64>::max(); |
| } |
| |
| StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override { |
| Tensor temporary_memory; |
| const DataType tf_data_type = ToTFDataType<T>::value; |
| int64 allocate_count = |
| Eigen::divup(byte_size, static_cast<int64>(sizeof(T))); |
| Status allocation_status(context_->allocate_temp( |
| tf_data_type, TensorShape({allocate_count}), &temporary_memory)); |
| if (!allocation_status.ok()) { |
| return ToExecutorStatus(allocation_status); |
| } |
| // Hold a reference to the allocated tensors until the allocator is
| // destroyed.
| allocated_tensors_.push_back(temporary_memory); |
| total_byte_size_ += byte_size; |
| return DeviceMemory<uint8>::MakeFromByteSize( |
| temporary_memory.template flat<T>().data(), |
| temporary_memory.template flat<T>().size() * sizeof(T)); |
| } |
| |
| int64 TotalByteSize() const { return total_byte_size_; } |
| |
| Tensor get_allocated_tensor(int index) const { |
| return allocated_tensors_[index]; |
| } |
| |
| private: |
| int64 total_byte_size_ = 0; |
| OpKernelContext* context_; // not owned |
| std::vector<Tensor> allocated_tensors_; |
| }; |
| |
| // A helper to allocate memory for Cudnn RNN models as a kernel output. It is
| // used by the forward-pass kernel to feed the reserve space to the backward
| // pass. The memory is expected to stay alive at least until the backward pass
| // has finished.
| template <typename T> |
| class CudnnRnnAllocatorInOutput : public ScratchAllocator { |
| public: |
| ~CudnnRnnAllocatorInOutput() override {} |
| CudnnRnnAllocatorInOutput(OpKernelContext* context, int output_index) |
| : context_(context), output_index_(output_index) {} |
| int64 GetMemoryLimitInBytes() override { |
| return std::numeric_limits<int64>::max(); |
| } |
| StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override { |
| CHECK(total_byte_size_ == 0) |
| << "Reserve space allocator can only be called once"; |
| int64 allocate_count = |
| Eigen::divup(byte_size, static_cast<int64>(sizeof(T))); |
| |
| Tensor* temporary_memory = nullptr; |
| Status allocation_status(context_->allocate_output( |
| output_index_, TensorShape({allocate_count}), &temporary_memory)); |
| if (!allocation_status.ok()) { |
| return ToExecutorStatus(allocation_status); |
| } |
| total_byte_size_ += byte_size; |
| auto memory_uint8 = DeviceMemory<uint8>::MakeFromByteSize( |
| temporary_memory->template flat<T>().data(), |
| temporary_memory->template flat<T>().size() * sizeof(T)); |
| return StatusOr<DeviceMemory<uint8>>(memory_uint8); |
| } |
| int64 TotalByteSize() { return total_byte_size_; } |
| |
| private: |
| int64 total_byte_size_ = 0; |
| OpKernelContext* context_; // not owned |
| int output_index_; |
| }; |
| |
| // A helper to allocate persistent memory for Cudnn RNN models, which is
| // expected to live across kernel invocations.
| // This class is not thread-safe. |
| class CudnnRNNPersistentSpaceAllocator : public ScratchAllocator { |
| public: |
| explicit CudnnRNNPersistentSpaceAllocator(OpKernelContext* context) |
| : context_(context) {} |
| |
| ~CudnnRNNPersistentSpaceAllocator() override {} |
| |
| int64 GetMemoryLimitInBytes() override { |
| return std::numeric_limits<int64>::max(); |
| } |
| |
| StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override { |
| if (total_byte_size_ != 0) { |
| return Status(error::FAILED_PRECONDITION, |
| "Persistent space allocator can only be called once"); |
| } |
| |
| Status allocation_status = context_->allocate_persistent( |
| DT_UINT8, TensorShape({byte_size}), &handle_, nullptr); |
| if (!allocation_status.ok()) { |
| return ToExecutorStatus(allocation_status); |
| } |
| total_byte_size_ += byte_size; |
| return AsDeviceMemory<uint8>(handle_.AccessTensor(context_)); |
| } |
| int64 TotalByteSize() { return total_byte_size_; } |
| |
| private: |
| int64 total_byte_size_ = 0; |
| PersistentTensor handle_; |
| OpKernelContext* context_; // not owned |
| }; |
| |
| struct CudnnModelTypes { |
| RnnMode rnn_mode; |
| TFRNNInputMode rnn_input_mode; |
| RnnDirectionMode rnn_direction_mode; |
| bool HasInputC() const { |
| // For Cudnn 5.0, only LSTM has input-c. All other models use only |
| // input-h. |
| return rnn_mode == RnnMode::kRnnLstm; |
| } |
| |
| string DebugString() const { |
| return strings::Printf( |
| "[rnn_mode, rnn_input_mode, rnn_direction_mode]: %d, %d, %d ", |
| static_cast<int>(rnn_mode), static_cast<int>(rnn_input_mode), |
| static_cast<int>(rnn_direction_mode)); |
| } |
| }; |
| |
| // A helper class that collects the shapes describing an RNN model.
| struct CudnnRnnModelShapes { |
| int num_layers; |
| int input_size; |
| int num_units; |
| int dir_count; |
| int max_seq_length; |
| int batch_size; |
| int cell_num_units = 0; |
| TensorShape input_shape; |
| TensorShape output_shape; |
| TensorShape hidden_state_shape; |
| TensorShape cell_state_shape; |
| // At present, only the fields relevant to the cached RnnDescriptor are
| // compared.
| bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const { |
| return num_layers == rhs.num_layers && input_size == rhs.input_size && |
| num_units == rhs.num_units && dir_count == rhs.dir_count && |
| cell_num_units == rhs.cell_num_units; |
| } |
| string DebugString() const { |
| return strings::Printf( |
| "[num_layers, input_size, num_units, dir_count, max_seq_length, " |
| "batch_size, cell_num_units]: [%d, %d, %d, %d, %d, %d, %d] ", |
| num_layers, input_size, num_units, dir_count, max_seq_length, |
| batch_size, cell_num_units); |
| } |
| }; |
| |
| // Utility class for using a CudnnRnnModelShapes and AlgorithmDesc pair as a
| // hash table key.
| struct CudnnRnnConfigHasher { |
| uint64 operator()( |
| const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>& |
| to_hash) const { |
| auto& shapes = to_hash.first; |
| auto& algo_desc = to_hash.second; |
| |
| uint64 hash = |
| HashList({shapes.num_layers, shapes.input_size, shapes.num_units, |
| shapes.dir_count, shapes.batch_size}); |
| if (algo_desc.has_value()) { |
| hash = Hash64Combine(hash, algo_desc->hash()); |
| } |
| return hash; |
| } |
| }; |
| |
| // Utility class for using CudnnRnnModelShapes and AlgorithmDesc pair as a hash |
| // table key. |
| struct CudnnRnnConfigComparator { |
| bool operator()( |
| const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>& lhs, |
| const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>& rhs) |
| const { |
| return lhs.first.IsCompatibleWith(rhs.first) && lhs.second == rhs.second; |
| } |
| }; |
| |
| // Pointers to RNN scratch space for a specific set of shape parameters (used as |
| // a hash table value in CudnnRNNForwardOp and CudnnRNNBackwardOp). |
| struct RnnScratchSpace { |
| std::unique_ptr<RnnDescriptor> rnn_desc; |
| std::unique_ptr<CudnnRNNPersistentSpaceAllocator> dropout_state_allocator; |
| }; |
| |
| // Extracts and checks the forward input tensors, parameters, and shapes from
| // the OpKernelContext.
| Status ExtractForwardInput(OpKernelContext* context, |
| const CudnnModelTypes& model_types, bool time_major, |
| const Tensor** input, const Tensor** input_h, |
| const Tensor** input_c, const Tensor** params, |
| const int num_proj, |
| CudnnRnnModelShapes* model_shapes) { |
| TF_RETURN_IF_ERROR(context->input("input", input)); |
| TF_RETURN_IF_ERROR(context->input("input_h", input_h)); |
| if (model_types.HasInputC()) { |
| TF_RETURN_IF_ERROR(context->input("input_c", input_c)); |
| } |
| TF_RETURN_IF_ERROR(context->input("params", params)); |
| |
| if ((*input)->dims() != 3) { |
| return errors::InvalidArgument("RNN input must be a 3-D vector."); |
| } |
| if (time_major) { |
| model_shapes->max_seq_length = (*input)->dim_size(0); |
| model_shapes->batch_size = (*input)->dim_size(1); |
| } else { |
| model_shapes->max_seq_length = (*input)->dim_size(1); |
| model_shapes->batch_size = (*input)->dim_size(0); |
| } |
| model_shapes->input_size = (*input)->dim_size(2); |
| model_shapes->input_shape = (*input)->shape(); |
| model_shapes->dir_count = |
| (model_types.rnn_direction_mode == RnnDirectionMode::kRnnBidirectional) |
| ? 2 |
| : 1; |
| |
| if ((*input_h)->dims() != 3) { |
| return errors::InvalidArgument("RNN input_h must be a 3-D vector."); |
| } |
| if (time_major) { |
| model_shapes->num_layers = |
| (*input_h)->dim_size(0) / model_shapes->dir_count; |
| } else { |
| model_shapes->num_layers = |
| (*input_h)->dim_size(1) / model_shapes->dir_count; |
| } |
| model_shapes->num_units = (*input_h)->dim_size(2); |
| |
| if (time_major) { |
| model_shapes->hidden_state_shape = |
| TensorShape({model_shapes->dir_count * model_shapes->num_layers, |
| model_shapes->batch_size, model_shapes->num_units}); |
| } else { |
| model_shapes->hidden_state_shape = |
| TensorShape({model_shapes->batch_size, |
| model_shapes->dir_count * model_shapes->num_layers, |
| model_shapes->num_units}); |
| } |
| if ((*input_h)->shape() != model_shapes->hidden_state_shape) { |
| return errors::InvalidArgument( |
| "Invalid input_h shape: ", (*input_h)->shape().DebugString(), " ", |
| model_shapes->hidden_state_shape.DebugString()); |
| } |
| if (model_types.HasInputC()) { |
| model_shapes->cell_num_units = (*input_c)->dim_size(2); |
| if (time_major) { |
| model_shapes->cell_state_shape = |
| TensorShape({model_shapes->dir_count * model_shapes->num_layers, |
| model_shapes->batch_size, model_shapes->cell_num_units}); |
| } else { |
| model_shapes->cell_state_shape = |
| TensorShape({model_shapes->batch_size, |
| model_shapes->dir_count * model_shapes->num_layers, |
| model_shapes->cell_num_units}); |
| } |
| if (num_proj == 0) { |
| if ((*input_h)->shape() != (*input_c)->shape()) { |
| return errors::InvalidArgument( |
| "input_h and input_c must have the same shape w/o projection: ", |
| (*input_h)->shape().DebugString(), " ", |
| (*input_c)->shape().DebugString()); |
| } |
| } else { |
| if ((*input_h)->dim_size(2) > (*input_c)->dim_size(2) || |
| num_proj != (*input_h)->dim_size(2) || |
| (*input_h)->dim_size(0) != (*input_c)->dim_size(0) || |
| (*input_h)->dim_size(1) != (*input_c)->dim_size(1)) { |
| return errors::InvalidArgument( |
| "Invalid input_h and input_c w/ projection size: ", num_proj, " ", |
| (*input_h)->shape().DebugString(), " ", |
| (*input_c)->shape().DebugString()); |
| } |
| } |
| } else { |
| // Set a dummy cell_state_shape. TODO(kaixih): remove the time_major branch.
| if (time_major) { |
| model_shapes->cell_state_shape = |
| TensorShape({model_shapes->dir_count * model_shapes->num_layers, |
| model_shapes->batch_size, model_shapes->num_units}); |
| } else { |
| model_shapes->cell_state_shape = |
| TensorShape({model_shapes->batch_size, |
| model_shapes->dir_count * model_shapes->num_layers, |
| model_shapes->num_units}); |
| } |
| model_shapes->cell_num_units = 0; |
| } |
| if (time_major) { |
| model_shapes->output_shape = |
| TensorShape({model_shapes->max_seq_length, model_shapes->batch_size, |
| model_shapes->dir_count * model_shapes->num_units}); |
| } else { |
| model_shapes->output_shape = |
| TensorShape({model_shapes->batch_size, model_shapes->max_seq_length, |
| model_shapes->dir_count * model_shapes->num_units}); |
| } |
| return Status::OK(); |
| } |
| |
| // Overload that additionally extracts the sequence_lengths input.
| Status ExtractForwardInput(OpKernelContext* context, |
| const CudnnModelTypes& model_types, bool time_major, |
| const Tensor** input, const Tensor** input_h, |
| const Tensor** input_c, const Tensor** params, |
| const Tensor** sequence_lengths, const int num_proj, |
| CudnnRnnModelShapes* model_shapes) { |
| TF_RETURN_IF_ERROR(context->input("sequence_lengths", sequence_lengths)); |
| return ExtractForwardInput(context, model_types, time_major, input, input_h, |
| input_c, params, num_proj, model_shapes); |
| } |
| |
| template <typename T> |
| Status CreateForwardAndBackwardIODescriptors( |
| OpKernelContext* context, const CudnnRnnModelShapes& model_shapes, |
| std::unique_ptr<RnnSequenceTensorDescriptor>* input_desc, |
| std::unique_ptr<RnnStateTensorDescriptor>* h_state_desc, |
| std::unique_ptr<RnnStateTensorDescriptor>* c_state_desc, |
| std::unique_ptr<RnnSequenceTensorDescriptor>* output_desc, |
| const absl::Span<const int>& seq_lengths, bool time_major) { |
| StreamExecutor* executor = context->op_device_context()->stream()->parent(); |
| se::dnn::DataType data_type = ToDataType<T>::value; |
| |
| const TensorShape& input_shape = model_shapes.input_shape; |
| const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape; |
| const TensorShape& cell_state_shape = model_shapes.cell_state_shape; |
| const TensorShape& output_shape = model_shapes.output_shape; |
| |
| DCHECK_EQ(input_shape.dims(), 3); |
| if (seq_lengths.data() != nullptr) { |
| if (time_major) { |
| auto input_desc_s = executor->createRnnSequenceTensorDescriptor( |
| input_shape.dim_size(0), input_shape.dim_size(1), |
| input_shape.dim_size(2), seq_lengths, time_major, data_type); |
| TF_RETURN_IF_ERROR(input_desc_s.status()); |
| *input_desc = input_desc_s.ConsumeValueOrDie(); |
| } else { |
| auto input_desc_s = executor->createRnnSequenceTensorDescriptor( |
| input_shape.dim_size(1), input_shape.dim_size(0), |
| input_shape.dim_size(2), seq_lengths, time_major, data_type); |
| TF_RETURN_IF_ERROR(input_desc_s.status()); |
| *input_desc = input_desc_s.ConsumeValueOrDie(); |
| } |
| } else { |
| auto input_desc_s = executor->createRnnSequenceTensorDescriptor( |
| input_shape.dim_size(0), input_shape.dim_size(1), |
| input_shape.dim_size(2), data_type); |
| TF_RETURN_IF_ERROR(input_desc_s.status()); |
| *input_desc = input_desc_s.ConsumeValueOrDie(); |
| } |
| |
| DCHECK_EQ(hidden_state_shape.dims(), 3); |
| if (time_major) { |
| auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor( |
| hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1), |
| hidden_state_shape.dim_size(2), data_type); |
| TF_RETURN_IF_ERROR(hidden_state_desc_s.status()); |
| *h_state_desc = hidden_state_desc_s.ConsumeValueOrDie(); |
| } else { |
| auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor( |
| hidden_state_shape.dim_size(1), hidden_state_shape.dim_size(0), |
| hidden_state_shape.dim_size(2), data_type); |
| TF_RETURN_IF_ERROR(hidden_state_desc_s.status()); |
| *h_state_desc = hidden_state_desc_s.ConsumeValueOrDie(); |
| } |
| |
| DCHECK_EQ(cell_state_shape.dims(), 3); |
| if (time_major) { |
| auto cell_state_desc_s = executor->createRnnStateTensorDescriptor( |
| cell_state_shape.dim_size(0), cell_state_shape.dim_size(1), |
| cell_state_shape.dim_size(2), data_type); |
| TF_RETURN_IF_ERROR(cell_state_desc_s.status()); |
| *c_state_desc = cell_state_desc_s.ConsumeValueOrDie(); |
| } else { |
| auto cell_state_desc_s = executor->createRnnStateTensorDescriptor( |
| cell_state_shape.dim_size(1), cell_state_shape.dim_size(0), |
| cell_state_shape.dim_size(2), data_type); |
| TF_RETURN_IF_ERROR(cell_state_desc_s.status()); |
| *c_state_desc = cell_state_desc_s.ConsumeValueOrDie(); |
| } |
| |
| DCHECK_EQ(output_shape.dims(), 3); |
| if (seq_lengths.data() != nullptr) { |
| if (time_major) { |
| auto output_desc_s = executor->createRnnSequenceTensorDescriptor( |
| output_shape.dim_size(0), output_shape.dim_size(1), |
| output_shape.dim_size(2), seq_lengths, time_major, data_type); |
| TF_RETURN_IF_ERROR(output_desc_s.status()); |
| *output_desc = output_desc_s.ConsumeValueOrDie(); |
| } else { |
| auto output_desc_s = executor->createRnnSequenceTensorDescriptor( |
| output_shape.dim_size(1), output_shape.dim_size(0), |
| output_shape.dim_size(2), seq_lengths, time_major, data_type); |
| TF_RETURN_IF_ERROR(output_desc_s.status()); |
| *output_desc = output_desc_s.ConsumeValueOrDie(); |
| } |
| } else { |
| auto output_desc_s = executor->createRnnSequenceTensorDescriptor( |
| output_shape.dim_size(0), output_shape.dim_size(1), |
| output_shape.dim_size(2), data_type); |
| TF_RETURN_IF_ERROR(output_desc_s.status()); |
| *output_desc = output_desc_s.ConsumeValueOrDie(); |
| } |
| |
| return Status::OK(); |
| } |
| |
| template <typename T> |
| Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc, |
| const CudnnModelTypes& model_types, |
| const CudnnRnnModelShapes& model_shapes, |
| /* forward inputs */ |
| const Tensor* input, const Tensor* input_h, |
| const Tensor* input_c, const Tensor* params, |
| const bool is_training, |
| /* forward outputs, outputs of the function */ |
| Tensor* output, Tensor* output_h, Tensor* output_c, |
| const Tensor* sequence_lengths, bool time_major, |
| ScratchAllocator* reserve_space_allocator, |
| ScratchAllocator* workspace_allocator, |
| ProfileResult* output_profile_result) { |
| std::unique_ptr<RnnSequenceTensorDescriptor> input_desc; |
| std::unique_ptr<RnnStateTensorDescriptor> h_state_desc; |
| std::unique_ptr<RnnStateTensorDescriptor> c_state_desc; |
| std::unique_ptr<RnnSequenceTensorDescriptor> output_desc; |
| |
| absl::Span<const int> seq_lengths; |
| if (sequence_lengths != nullptr) { |
| seq_lengths = absl::Span<const int>( |
| sequence_lengths->template flat<int>().data(), model_shapes.batch_size); |
| } |
| TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>( |
| context, model_shapes, &input_desc, &h_state_desc, &c_state_desc, |
| &output_desc, seq_lengths, time_major)); |
| |
| auto input_data = AsDeviceMemory<T>(input); |
| auto input_h_data = AsDeviceMemory<T>(input_h); |
| DeviceMemory<T> input_c_data; |
| if (model_types.HasInputC()) { |
| input_c_data = AsDeviceMemory<T>(input_c); |
| } |
| |
| auto params_data = AsDeviceMemory<T>(params); |
| auto output_data = AsDeviceMemory<T>(output); |
| auto output_h_data = AsDeviceMemory<T>(output_h); |
| DeviceMemory<T> output_c_data; |
| if (model_types.HasInputC()) { |
| output_c_data = AsDeviceMemory<T>(output_c); |
| } |
| |
| Stream* stream = context->op_device_context()->stream(); |
| bool launch_success = |
| stream |
| ->ThenRnnForward(rnn_desc, *input_desc, input_data, *h_state_desc, |
| input_h_data, *c_state_desc, input_c_data, |
| params_data, *output_desc, &output_data, |
| *h_state_desc, &output_h_data, *c_state_desc, |
| &output_c_data, is_training, reserve_space_allocator, |
| workspace_allocator, output_profile_result) |
| .ok(); |
| return launch_success |
| ? Status::OK() |
| : errors::Internal( |
| "Failed to call ThenRnnForward with model config: ", |
| model_types.DebugString(), ", ", model_shapes.DebugString()); |
| } |
| |
| template <typename T> |
| Status DoBackward( |
| OpKernelContext* context, const RnnDescriptor& rnn_desc, |
| const CudnnModelTypes& model_types, const CudnnRnnModelShapes& model_shapes, |
| /* forward inputs */ |
| const Tensor* input, const Tensor* input_h, const Tensor* input_c, |
| const Tensor* params, |
| /* forward outputs */ |
| const Tensor* output, const Tensor* output_h, const Tensor* output_c, |
| /* backprop inputs */ |
| const Tensor* output_backprop, const Tensor* output_h_backprop, |
| const Tensor* output_c_backprop, const Tensor* reserve_space, |
| /* backprop outputs, output of the function */ |
| Tensor* input_backprop, Tensor* input_h_backprop, Tensor* input_c_backprop, |
| Tensor* params_backprop, const Tensor* sequence_lengths, bool time_major, |
| ScratchAllocator* workspace_allocator, |
| ProfileResult* output_profile_result) { |
| std::unique_ptr<RnnSequenceTensorDescriptor> input_desc; |
| std::unique_ptr<RnnStateTensorDescriptor> h_state_desc; |
| std::unique_ptr<RnnStateTensorDescriptor> c_state_desc; |
| std::unique_ptr<RnnSequenceTensorDescriptor> output_desc; |
| |
| absl::Span<const int> seq_lengths; |
| if (sequence_lengths != nullptr) { |
| seq_lengths = absl::Span<const int>( |
| sequence_lengths->template flat<int>().data(), model_shapes.batch_size); |
| } |
| TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>( |
| context, model_shapes, &input_desc, &h_state_desc, &c_state_desc, |
| &output_desc, seq_lengths, time_major)); |
| |
| auto input_data = AsDeviceMemory<T>(input); |
| auto input_h_data = AsDeviceMemory<T>(input_h); |
| DeviceMemory<T> input_c_data; |
| if (model_types.HasInputC()) { |
| input_c_data = AsDeviceMemory<T>(input_c); |
| } |
| auto params_data = AsDeviceMemory<T>(params); |
| auto output_data = AsDeviceMemory<T>(output); |
| auto output_h_data = AsDeviceMemory<T>(output_h); |
| DeviceMemory<T> output_c_data; |
| if (model_types.HasInputC()) { |
| output_c_data = AsDeviceMemory<T>(output_c); |
| } |
| auto output_backprop_data = AsDeviceMemory<T>(output_backprop); |
| auto output_h_backprop_data = AsDeviceMemory<T>(output_h_backprop); |
| DeviceMemory<T> output_c_backprop_data; |
| if (model_types.HasInputC()) { |
| output_c_backprop_data = AsDeviceMemory<T>(output_c_backprop); |
| } |
| auto input_backprop_data = AsDeviceMemory<T>(input_backprop); |
| auto input_h_backprop_data = AsDeviceMemory<T>(input_h_backprop); |
| DeviceMemory<T> input_c_backprop_data; |
| if (model_types.HasInputC()) { |
| input_c_backprop_data = AsDeviceMemory<T>(input_c_backprop); |
| } |
| auto params_backprop_data = AsDeviceMemory<T>(params_backprop); |
| auto reserve_space_uint8 = |
| CastDeviceMemory<uint8, T>(const_cast<Tensor*>(reserve_space)); |
| |
| // Creates a memory callback for the workspace. The memory lives until the
| // end of this kernel call.
| Stream* stream = context->op_device_context()->stream(); |
| bool launch_success = |
| stream |
| ->ThenRnnBackward( |
| rnn_desc, *input_desc, input_data, *h_state_desc, input_h_data, |
| *c_state_desc, input_c_data, params_data, *output_desc, |
| output_data, *h_state_desc, output_h_data, *c_state_desc, |
| output_c_data, output_backprop_data, output_h_backprop_data, |
| output_c_backprop_data, &input_backprop_data, |
| &input_h_backprop_data, &input_c_backprop_data, |
| ¶ms_backprop_data, &reserve_space_uint8, workspace_allocator, |
| output_profile_result) |
| .ok(); |
| return launch_success |
| ? Status::OK() |
| : errors::Internal( |
| "Failed to call ThenRnnBackward with model config: ", |
| model_types.DebugString(), ", ", model_shapes.DebugString()); |
| } |
| |
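| // Copies the canonical weight/bias tensors in params_input into the matching
| // regions of the opaque parameter buffer pointed to by data_dst.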
| template <typename T> |
| void RestoreParams(const OpInputList params_input, |
| const std::vector<RnnDescriptor::ParamsRegion>& params, |
| DeviceMemoryBase* data_dst, Stream* stream) { |
| int num_params = params.size(); |
| CHECK(params_input.size() == num_params) |
| << "Number of params mismatch. Expected " << params_input.size() |
| << ", got " << num_params; |
| for (int i = 0; i < params.size(); i++) { |
| int64 size_in_bytes = params[i].size; |
| int64 size = size_in_bytes / sizeof(T); |
| CHECK(size == params_input[i].NumElements()) |
| << "Params size mismatch. Expected " << size << ", got " |
| << params_input[i].NumElements(); |
| auto data_src_ptr = StreamExecutorUtil::AsDeviceMemory<T>(params_input[i]); |
| DeviceMemoryBase data_dst_ptr = |
| SliceDeviceMemory(*data_dst, params[i].offset, size_in_bytes); |
| stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes); |
| } |
| } |
| |
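| // Returns true if the padded-IO Cudnn path should be used, i.e. unless the
| // input is time-major and every sequence already has the maximum length.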
| bool ShouldUsePaddedIO(const Tensor* sequence_lengths, |
| const CudnnRnnModelShapes& model_shapes, |
| bool time_major) { |
| auto seq_array = sequence_lengths->template flat<int>().data(); |
| bool all_max_seq_length = true; |
| for (int i = 0; i < model_shapes.batch_size; i++) { |
| if (seq_array[i] != model_shapes.max_seq_length) { |
| all_max_seq_length = false; |
| break; |
| } |
| } |
| return !(time_major && all_max_seq_length); |
| } |
| |
| } // namespace |
| |
| // Note: all of the following kernels depend on an RnnDescriptor instance,
| // which, according to the official Cudnn documentation, should be kept around
| // and reused across all Cudnn kernels in the same model.
| // In TensorFlow we don't pass the reference across different OpKernels;
| // instead, we recreate it separately in each OpKernel, which does not cause a
| // problem: CudnnDropoutDescriptor keeps a reference to the memory holding the
| // random number generator state, and that state is lost during recreation,
| // but only the forward-pass Cudnn APIs make use of it.
| |
| // A common base class for RNN kernels. It extracts common attributes and
| // performs shape validations.
| class CudnnRNNKernelCommon : public OpKernel { |
| protected: |
| explicit CudnnRNNKernelCommon(OpKernelConstruction* context) |
| : OpKernel(context) { |
| OP_REQUIRES_OK(context, context->GetAttr("dropout", &dropout_)); |
| OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_)); |
| OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_)); |
| string str; |
| OP_REQUIRES_OK(context, context->GetAttr("rnn_mode", &str)); |
| OP_REQUIRES_OK(context, ParseRNNMode(str, &model_types_.rnn_mode)); |
| OP_REQUIRES_OK(context, context->GetAttr("input_mode", &str)); |
| OP_REQUIRES_OK(context, |
| ParseTFRNNInputMode(str, &model_types_.rnn_input_mode)); |
| OP_REQUIRES_OK(context, context->GetAttr("direction", &str)); |
| OP_REQUIRES_OK( |
| context, ParseRNNDirectionMode(str, &model_types_.rnn_direction_mode)); |
| // Reset the CudnnRnnDescriptor and the related random number generator
| // state in every Compute() call.
| OP_REQUIRES_OK(context, ReadBoolFromEnvVar("TF_CUDNN_RESET_RND_GEN_STATE", |
| false, &reset_rnd_gen_state_)); |
| } |
| |
| bool HasInputC() const { return model_types_.HasInputC(); } |
| RnnMode rnn_mode() const { return model_types_.rnn_mode; } |
| TFRNNInputMode rnn_input_mode() const { return model_types_.rnn_input_mode; } |
| RnnDirectionMode rnn_direction_mode() const { |
| return model_types_.rnn_direction_mode; |
| } |
| const CudnnModelTypes& model_types() const { return model_types_; } |
| float dropout() const { return dropout_; } |
| uint64 seed() { return (static_cast<uint64>(seed_) << 32) | seed2_; } |
| bool ResetRndGenState() { return reset_rnd_gen_state_; } |
| |
| template <typename T> |
| Status ExtractCudnnRNNParamsInfo(OpKernelContext* context, int num_proj, |
| std::unique_ptr<RnnDescriptor>* rnn_desc) { |
| const Tensor* num_layers_t = nullptr; |
| TF_RETURN_IF_ERROR(context->input("num_layers", &num_layers_t)); |
| if (!TensorShapeUtils::IsScalar(num_layers_t->shape())) { |
| return errors::InvalidArgument("num_layers is not a scalar"); |
| } |
| int num_layers = num_layers_t->scalar<int>()(); |
| const Tensor* num_units_t = nullptr; |
| TF_RETURN_IF_ERROR(context->input("num_units", &num_units_t)); |
| if (!TensorShapeUtils::IsScalar(num_units_t->shape())) { |
| return errors::InvalidArgument("num_units is not a scalar"); |
| } |
| int num_units = num_units_t->scalar<int>()(); |
| const Tensor* input_size_t = nullptr; |
| TF_RETURN_IF_ERROR(context->input("input_size", &input_size_t)); |
| if (!TensorShapeUtils::IsScalar(input_size_t->shape())) { |
| return errors::InvalidArgument("input_size is not a scalar"); |
| } |
| int input_size = input_size_t->scalar<int>()(); |
| |
| int h_num_units = (num_proj == 0 ? num_units : num_proj); |
| int c_num_units = (num_proj == 0 ? 0 : num_units); |
| |
| RnnInputMode input_mode; |
| TF_RETURN_IF_ERROR( |
| ToRNNInputMode(rnn_input_mode(), num_units, input_size, &input_mode)); |
| |
| Stream* stream = context->op_device_context()->stream(); |
| // ExtractCudnnRNNParamsInfo is only called by op kernels that do not
| // require a random number generator, so state_allocator is set to nullptr.
| const AlgorithmConfig algo_config; |
| auto rnn_desc_s = stream->parent()->createRnnDescriptor( |
| num_layers, h_num_units, input_size, /*cell_size=*/c_num_units, |
| /*batch_size=*/0, input_mode, rnn_direction_mode(), rnn_mode(), |
| ToDataType<T>::value, algo_config, dropout(), seed(), |
| /* state_allocator=*/nullptr, /*use_padded_io=*/false); |
| if (!rnn_desc_s.ok()) { |
| return FromExecutorStatus(rnn_desc_s); |
| } |
| *rnn_desc = rnn_desc_s.ConsumeValueOrDie(); |
| return Status::OK(); |
| } |
| |
| template <typename T> |
| Status CreateRnnDescriptor(OpKernelContext* context, |
| const CudnnRnnModelShapes& model_shapes, |
| const RnnInputMode& input_mode, |
| const AlgorithmConfig& algo_config, |
| ScratchAllocator* dropout_state_allocator, |
| std::unique_ptr<RnnDescriptor>* rnn_desc, |
| bool use_padded_io) { |
| StreamExecutor* executor = context->op_device_context()->stream()->parent(); |
| se::dnn::DataType data_type = ToDataType<T>::value; |
| auto rnn_desc_s = executor->createRnnDescriptor( |
| model_shapes.num_layers, model_shapes.num_units, |
| model_shapes.input_size, model_shapes.cell_num_units, |
| model_shapes.batch_size, input_mode, rnn_direction_mode(), rnn_mode(), |
| data_type, algo_config, dropout(), seed(), dropout_state_allocator, |
| use_padded_io); |
| TF_RETURN_IF_ERROR(rnn_desc_s.status()); |
| |
| *rnn_desc = rnn_desc_s.ConsumeValueOrDie(); |
| return Status::OK(); |
| } |
| |
| using RnnStateCache = gtl::FlatMap< |
| std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>, |
| RnnScratchSpace, CudnnRnnConfigHasher, CudnnRnnConfigComparator>; |
| // Returns a raw rnn descriptor pointer. The cache owns the rnn descriptor and |
| // should outlive the returned pointer. |
| template <typename T> |
| Status GetCachedRnnDescriptor(OpKernelContext* context, |
| const CudnnRnnModelShapes& model_shapes, |
| const RnnInputMode& input_mode, |
| const AlgorithmConfig& algo_config, |
| RnnStateCache* cache, RnnDescriptor** rnn_desc, |
| bool use_padded_io) { |
| auto key = std::make_pair(model_shapes, algo_config.algorithm()); |
| RnnScratchSpace& rnn_state = (*cache)[key]; |
| if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) { |
| CudnnRNNPersistentSpaceAllocator* dropout_state_allocator = |
| new CudnnRNNPersistentSpaceAllocator(context); |
| rnn_state.dropout_state_allocator.reset(dropout_state_allocator); |
| Status status = CreateRnnDescriptor<T>( |
| context, model_shapes, input_mode, algo_config, |
| dropout_state_allocator, &rnn_state.rnn_desc, use_padded_io); |
| TF_RETURN_IF_ERROR(status); |
| } |
| *rnn_desc = rnn_state.rnn_desc.get(); |
| return Status::OK(); |
| } |
| |
| private: |
| int seed_; |
| int seed2_; |
| float dropout_; |
| bool reset_rnd_gen_state_; |
| |
| CudnnModelTypes model_types_; |
| }; |
| |
| // A kernel that returns the size of the opaque parameter buffer. The user
| // should use it to create the actual parameter buffer for training. However,
| // the opaque buffer should not be used for saving and restoring.
| template <typename T, typename Index> |
| class CudnnRNNParamsSizeOp<GPUDevice, T, Index> : public CudnnRNNKernelCommon { |
| public: |
| explicit CudnnRNNParamsSizeOp(OpKernelConstruction* context) |
| : CudnnRNNKernelCommon(context) { |
| if (context->HasAttr("num_proj")) { |
| OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_)); |
| } else { |
| num_proj_ = 0; |
| } |
| } |
| |
| void Compute(OpKernelContext* context) override { |
| std::unique_ptr<RnnDescriptor> rnn_desc; |
| OP_REQUIRES_OK(context, |
| ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc)); |
| int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes(); |
| CHECK(params_size_in_bytes % sizeof(T) == 0) |
| << "params_size_in_bytes must be multiple of element size"; |
| int64 params_size = params_size_in_bytes / sizeof(T); |
| |
| Tensor* output_t = nullptr; |
| OP_REQUIRES_OK(context, context->allocate_output(0, {1}, &output_t)); |
| *output_t->template flat<Index>().data() = params_size; |
| } |
| |
| private: |
| int num_proj_; |
| }; |
| |
| #define REGISTER_GPU(T) \ |
| REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsSize") \ |
| .Device(DEVICE_GPU) \ |
| .HostMemory("num_layers") \ |
| .HostMemory("num_units") \ |
| .HostMemory("input_size") \ |
| .HostMemory("params_size") \ |
| .TypeConstraint<T>("T") \ |
| .TypeConstraint<int32>("S"), \ |
| CudnnRNNParamsSizeOp<GPUDevice, T, int32>); |
| |
| TF_CALL_half(REGISTER_GPU); |
| TF_CALL_float(REGISTER_GPU); |
| TF_CALL_double(REGISTER_GPU); |
| #undef REGISTER_GPU |
| |
| // Convert weight and bias params from a platform-specific layout to the |
| // canonical form. |
| template <typename T> |
| class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon { |
| public: |
| explicit CudnnRNNParamsToCanonical(OpKernelConstruction* context) |
| : CudnnRNNKernelCommon(context) { |
| if (context->HasAttr("num_params")) { |
| OP_REQUIRES_OK(context, context->GetAttr("num_params", &num_params_)); |
| } else { |
| num_params_ = 0; |
| } |
| if (context->HasAttr("num_params_weights")) { |
| OP_REQUIRES_OK(context, context->GetAttr("num_params_weights", |
| &num_params_weights_)); |
| } else { |
| num_params_weights_ = 0; |
| } |
| if (context->HasAttr("num_params_biases")) { |
| OP_REQUIRES_OK( |
| context, context->GetAttr("num_params_biases", &num_params_biases_)); |
| } else { |
| num_params_biases_ = 0; |
| } |
| if (context->HasAttr("num_proj")) { |
| OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_)); |
| } else { |
| num_proj_ = 0; |
| } |
| if (num_proj_ == 0) { |
| num_params_weights_ = num_params_; |
| num_params_biases_ = num_params_; |
| } |
| } |
| |
| void Compute(OpKernelContext* context) override { |
| const Tensor& input = context->input(3); |
| auto input_ptr = StreamExecutorUtil::AsDeviceMemory<T>(input); |
| Stream* stream = context->op_device_context()->stream(); |
| |
| std::unique_ptr<RnnDescriptor> rnn_desc; |
| OP_REQUIRES_OK(context, |
| ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc)); |
| int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes(); |
| CHECK(params_size_in_bytes % sizeof(T) == 0) |
| << "params_size_in_bytes must be multiple of element size"; |
| |
| const Tensor* num_units_t = nullptr; |
| OP_REQUIRES_OK(context, context->input("num_units", &num_units_t)); |
| CHECK(TensorShapeUtils::IsScalar(num_units_t->shape())) |
| << "num_units is not a scalar"; |
| int num_units = num_units_t->scalar<int>()(); |
| |
| const Tensor* input_size_t = nullptr; |
| OP_REQUIRES_OK(context, context->input("input_size", &input_size_t)); |
| CHECK(TensorShapeUtils::IsScalar(input_size_t->shape())) |
| << "input_size is not a scalar"; |
| int input_size = input_size_t->scalar<int>()(); |
| |
| const Tensor* num_layers_t = nullptr; |
| OP_REQUIRES_OK(context, context->input("num_layers", &num_layers_t)); |
| CHECK(TensorShapeUtils::IsScalar(num_layers_t->shape())) |
| << "num_layers is not a scalar"; |
| int num_layers = num_layers_t->scalar<int>()(); |
| int num_dirs = 1; |
| if (rnn_direction_mode() == RnnDirectionMode::kRnnBidirectional) { |
| num_dirs = 2; |
| } |
| const int num_params_weights_per_layer = |
| num_params_weights_ / num_layers / num_dirs; |
| // Number of params applied on inputs. The rest are applied on recurrent |
| // hidden states. |
| const int num_params_input_state = num_params_weights_per_layer / 2; |
| OP_REQUIRES( |
| context, num_params_weights_ % (num_layers * num_dirs) == 0, |
| errors::InvalidArgument("Number of params (weights) is not a multiple" |
| "of num_layers * num_dirs.")); |
| OP_REQUIRES( |
| context, num_params_biases_ % (num_layers * num_dirs) == 0, |
| errors::InvalidArgument("Number of params (biases) is not a multiple" |
| "of num_layers * num_dirs.")); |
| if (num_proj_ == 0) { |
| OP_REQUIRES( |
| context, num_params_weights_per_layer % 2 == 0, |
| errors::InvalidArgument("Number of params (weights) per layer is not" |
| "an even number with no projection.")); |
| } else { |
| OP_REQUIRES( |
| context, num_params_weights_per_layer % 2 != 0, |
| errors::InvalidArgument("Number of params (weights) per layer is not" |
| "an odl number with projection.")); |
| } |
| |
| OP_REQUIRES( |
| context, num_params_weights_ == rnn_desc->ParamsWeightRegions().size(), |
| errors::InvalidArgument("C Number of params mismatch. Expected ", |
| num_params_weights_, ", got ", |
| rnn_desc->ParamsWeightRegions().size())); |
| int h_num_units = (num_proj_ == 0 ? num_units : num_proj_); |
| int c_num_units = (num_proj_ == 0 ? 0 : num_units); |
| for (int i = 0; i < rnn_desc->ParamsWeightRegions().size(); i++) { |
| int64 size_in_bytes = rnn_desc->ParamsWeightRegions()[i].size; |
| int64 size = size_in_bytes / sizeof(T); |
| const int layer_idx = i / num_params_weights_per_layer; |
| const int index_within_layer = i % num_params_weights_per_layer; |
| int width = 0, height = (num_proj_ == 0 ? h_num_units : c_num_units); |
| // In the CuDNN layout, each layer has num_params_weights_per_layer params:
| // the first half (a.k.a. num_params_input_state params) is applied to the
| // inputs, and the second half to the recurrent hidden states.
| bool apply_on_input_state = index_within_layer < num_params_input_state; |
| if (rnn_direction_mode() == RnnDirectionMode::kRnnUnidirectional) { |
| if (layer_idx == 0 && apply_on_input_state) { |
| width = input_size; |
| } else { |
| width = h_num_units; |
| } |
| } else { |
| if (apply_on_input_state) { |
| if (layer_idx <= 1) { |
| // First forward or backward layer.
| width = input_size; |
| } else { |
| // For the following layers, the cell inputs are the concatenated
| // outputs of the prior layer.
| width = 2 * h_num_units; |
| } |
| } else { |
| width = h_num_units; |
| } |
| } |
| CHECK(size == width * height) << "Params size mismatch. Expected " |
| << width * height << ", got " << size; |
| Tensor* output = nullptr; |
| int id_in_layer = i % num_params_weights_per_layer; |
| if (num_proj_ != 0 && id_in_layer == num_params_weights_per_layer - 1) { |
| std::swap(height, width); |
| } |
| OP_REQUIRES_OK(context, context->allocate_output( |
| i, TensorShape({height, width}), &output)); |
| DeviceMemoryBase data_src_ptr = SliceDeviceMemory( |
| input_ptr, rnn_desc->ParamsWeightRegions()[i].offset, size_in_bytes); |
| auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output); |
| stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes); |
| } |
| |
| OP_REQUIRES( |
| context, num_params_biases_ == rnn_desc->ParamsBiasRegions().size(), |
| errors::InvalidArgument("A Number of params mismatch. Expected ", |
| num_params_biases_, ", got ", |
| rnn_desc->ParamsBiasRegions().size())); |
| for (int i = 0; i < rnn_desc->ParamsBiasRegions().size(); i++) { |
| int64 size_in_bytes = rnn_desc->ParamsBiasRegions()[i].size; |
| int64 size = size_in_bytes / sizeof(T); |
| OP_REQUIRES(context, size == num_units, |
| errors::InvalidArgument("Params size mismatch. Expected ", |
| num_units, ", got ", size)); |
| |
| Tensor* output = nullptr; |
| OP_REQUIRES_OK(context, |
| context->allocate_output(num_params_weights_ + i, |
| TensorShape({size}), &output)); |
| DeviceMemoryBase data_src_ptr = SliceDeviceMemory( |
| input_ptr, rnn_desc->ParamsBiasRegions()[i].offset, size_in_bytes); |
| auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output); |
| stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes); |
| } |
| } |
| |
| private: |
| int num_params_; |
| int num_params_weights_; |
| int num_params_biases_; |
| int num_proj_; |
| }; |
| |
| #define REGISTER_GPU(T) \ |
| REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsToCanonical") \ |
| .Device(DEVICE_GPU) \ |
| .HostMemory("num_layers") \ |
| .HostMemory("num_units") \ |
| .HostMemory("input_size") \ |
| .TypeConstraint<T>("T"), \ |
| CudnnRNNParamsToCanonical<GPUDevice, T>); |
| TF_CALL_half(REGISTER_GPU); |
| TF_CALL_float(REGISTER_GPU); |
| TF_CALL_double(REGISTER_GPU); |
| #undef REGISTER_GPU |
| |
| #define REGISTER_GPU(T) \ |
| REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsToCanonicalV2") \ |
| .Device(DEVICE_GPU) \ |
| .HostMemory("num_layers") \ |
| .HostMemory("num_units") \ |
| .HostMemory("input_size") \ |
| .TypeConstraint<T>("T"), \ |
| CudnnRNNParamsToCanonical<GPUDevice, T>); |
| TF_CALL_half(REGISTER_GPU); |
| TF_CALL_float(REGISTER_GPU); |
| TF_CALL_double(REGISTER_GPU); |
| #undef REGISTER_GPU |
| |
| // Convert weight and bias params from the canonical form to a |
| // platform-specific layout. |
| template <typename T> |
| class CudnnRNNCanonicalToParams<GPUDevice, T> : public CudnnRNNKernelCommon { |
| public: |
| explicit CudnnRNNCanonicalToParams(OpKernelConstruction* context) |
| : CudnnRNNKernelCommon(context) { |
| if (context->HasAttr("num_proj")) { |
| OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_)); |
| } else { |
| num_proj_ = 0; |
| } |
| } |
| |
| void Compute(OpKernelContext* context) override { |
| std::unique_ptr<RnnDescriptor> rnn_desc; |
| OP_REQUIRES_OK(context, |
| ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc)); |
| int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes(); |
| CHECK(params_size_in_bytes % sizeof(T) == 0) |
| << "params_size_in_bytes must be multiple of element size"; |
| Tensor* output = nullptr; |
| int params_size = params_size_in_bytes / sizeof(T); |
| OP_REQUIRES_OK(context, |
| context->allocate_output(0, {params_size}, &output)); |
| auto output_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output); |
| Stream* stream = context->op_device_context()->stream(); |
| |
| OpInputList weights; |
| OP_REQUIRES_OK(context, context->input_list("weights", &weights)); |
| RestoreParams<T>(weights, rnn_desc->ParamsWeightRegions(), &output_ptr, |
| stream); |
| |
| OpInputList biases; |
| OP_REQUIRES_OK(context, context->input_list("biases", &biases)); |
| RestoreParams<T>(biases, rnn_desc->ParamsBiasRegions(), &output_ptr, |
| stream); |
| } |
| |
| private: |
| int num_proj_; |
| }; |
| |
| #define REGISTER_GPU(T) \ |
| REGISTER_KERNEL_BUILDER(Name("CudnnRNNCanonicalToParams") \ |
| .Device(DEVICE_GPU) \ |
| .HostMemory("num_layers") \ |
| .HostMemory("num_units") \ |
| .HostMemory("input_size") \ |
| .TypeConstraint<T>("T"), \ |
| CudnnRNNCanonicalToParams<GPUDevice, T>); |
| TF_CALL_half(REGISTER_GPU); |
| TF_CALL_float(REGISTER_GPU); |
| TF_CALL_double(REGISTER_GPU); |
| #undef REGISTER_GPU |
| |
| #define REGISTER_GPU(T) \ |
| REGISTER_KERNEL_BUILDER(Name("CudnnRNNCanonicalToParamsV2") \ |
| .Device(DEVICE_GPU) \ |
| .HostMemory("num_layers") \ |
| .HostMemory("num_units") \ |
| .HostMemory("input_size") \ |
| .TypeConstraint<T>("T"), \ |
| CudnnRNNCanonicalToParams<GPUDevice, T>); |
| TF_CALL_half(REGISTER_GPU); |
| TF_CALL_float(REGISTER_GPU); |
| TF_CALL_double(REGISTER_GPU); |
| #undef REGISTER_GPU |
| |
| // Run the forward operation of the RNN model. |
| template <typename T> |
| class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon { |
| public: |
| explicit CudnnRNNForwardOp(OpKernelConstruction* context) |
| : CudnnRNNKernelCommon(context) { |
| OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_)); |
| |
| // Read debug env variables. |
| is_debug_mode_ = DebugCudnnRnn(); |
| debug_cudnn_rnn_algo_ = DebugCudnnRnnAlgo(); |
| debug_use_tensor_ops_ = DebugCudnnRnnUseTensorOps(); |
| } |
| |
| void Compute(OpKernelContext* context) override { |
| AlgorithmConfig algo_config; |
| ComputeAndReturnAlgorithm(context, &algo_config, /*var_seq_lengths=*/false, |
| /*time_major=*/true, /*num_proj=*/0); |
| } |
| |
| protected: |
| virtual void ComputeAndReturnAlgorithm(OpKernelContext* context, |
| AlgorithmConfig* output_algo_config, |
| bool var_seq_lengths, bool time_major, |
| int num_proj) { |
| CHECK_NE(output_algo_config, nullptr); |
| |
| const Tensor* input = nullptr; |
| const Tensor* input_h = nullptr; |
| const Tensor* input_c = nullptr; |
| const Tensor* params = nullptr; |
| const Tensor* sequence_lengths = nullptr; |
| CudnnRnnModelShapes model_shapes; |
| bool use_padded_io = false; |
| if (var_seq_lengths) { |
| OP_REQUIRES_OK(context, ExtractForwardInput( |
| context, model_types(), time_major, &input, |
| &input_h, &input_c, ¶ms, |
| &sequence_lengths, num_proj, &model_shapes)); |
| use_padded_io = |
| ShouldUsePaddedIO(sequence_lengths, model_shapes, time_major); |
| } else { |
| OP_REQUIRES_OK(context, |
| ExtractForwardInput(context, model_types(), time_major, |
| &input, &input_h, &input_c, ¶ms, |
| num_proj, &model_shapes)); |
| } |
| RnnInputMode input_mode; |
| OP_REQUIRES_OK(context, |
| ToRNNInputMode(rnn_input_mode(), model_shapes.num_units, |
| model_shapes.input_size, &input_mode)); |
| |
| Tensor* output = nullptr; |
| Tensor* output_h = nullptr; |
| Tensor* output_c = nullptr; |
| OP_REQUIRES_OK(context, AllocateOutputs(context, model_shapes, &output, |
| &output_h, &output_c)); |
| |
| // Creates a memory callback for the reserve_space. The memory lives in |
| // output 3 of this kernel so that it can be fed into the backward pass |
| // when needed. |
| CudnnRnnAllocatorInOutput<T> reserve_space_allocator(context, 3); |
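| // Note on the hand-off: the reserve space produced through this allocator |
| // is exposed as kernel output 3, and the backward op later reads the same |
| // tensor back through its "reserve_space" input (see |
| // ExtractBackwardInputs() below). Its contents are opaque to TensorFlow; |
| // only cuDNN interprets them. |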
| // Creates a memory callback for the workspace. The memory only needs to |
| // live until the end of this kernel call. |
| CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context); |
| |
| if (is_debug_mode_) { |
| AlgorithmDesc algo_desc(debug_cudnn_rnn_algo_, debug_use_tensor_ops_); |
| output_algo_config->set_algorithm(algo_desc); |
| } else { |
| OP_REQUIRES_OK(context, |
| MaybeAutoTune(context, model_shapes, input_mode, input, |
| input_h, input_c, params, output, output_h, |
| output_c, output_algo_config)); |
| } |
| |
| Status launch_status; |
| { |
| mutex_lock l(mu_); |
| RnnDescriptor* rnn_desc_ptr = nullptr; |
| OP_REQUIRES_OK(context, |
| GetCachedRnnDescriptor<T>( |
| context, model_shapes, input_mode, *output_algo_config, |
| &rnn_state_cache_, &rnn_desc_ptr, use_padded_io)); |
| launch_status = DoForward<T>( |
| context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h, |
| input_c, params, is_training_, output, output_h, output_c, |
| sequence_lengths, time_major, &reserve_space_allocator, |
| &workspace_allocator, /*output_profile_result=*/nullptr); |
| } |
| OP_REQUIRES_OK(context, launch_status); |
| } |
| |
| protected: |
| virtual Status MaybeAutoTune(OpKernelContext* context, |
| const CudnnRnnModelShapes& model_shapes, |
| const RnnInputMode& input_mode, |
| const Tensor* input, const Tensor* input_h, |
| const Tensor* input_c, const Tensor* params, |
| Tensor* output, Tensor* output_h, |
| Tensor* output_c, |
| AlgorithmConfig* best_algo_config) { |
| CHECK_NE(best_algo_config, nullptr); |
| *best_algo_config = AlgorithmConfig(); |
| return Status::OK(); |
| } |
| |
| bool is_training() const { return is_training_; } |
| bool is_debug_mode_; |
| bool debug_use_tensor_ops_; |
| int64 debug_cudnn_rnn_algo_; |
| |
| private: |
| Status AllocateOutputs(OpKernelContext* context, |
| const CudnnRnnModelShapes& model_shapes, |
| Tensor** output, Tensor** output_h, |
| Tensor** output_c) { |
| const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape; |
| const TensorShape& output_shape = model_shapes.output_shape; |
| const TensorShape& cell_state_shape = model_shapes.cell_state_shape; |
| |
| TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, output)); |
| TF_RETURN_IF_ERROR( |
| context->allocate_output(1, hidden_state_shape, output_h)); |
| if (HasInputC()) { |
| TF_RETURN_IF_ERROR( |
| context->allocate_output(2, cell_state_shape, output_c)); |
| } else { |
| // Only LSTM uses input_c and output_c, so for all other model types we |
| // only need to allocate a dummy output. |
| TF_RETURN_IF_ERROR(context->allocate_output(2, {}, output_c)); |
| } |
| if (!is_training_) { |
| Tensor* dummy_reserve_space = nullptr; |
| TF_RETURN_IF_ERROR(context->allocate_output(3, {}, &dummy_reserve_space)); |
| } |
| return Status::OK(); |
| } |
| |
| mutex mu_; |
| bool is_training_; |
| RnnStateCache rnn_state_cache_ GUARDED_BY(mu_); |
| }; |
| |
| #define REGISTER_GPU(T) \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("CudnnRNN").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ |
| CudnnRNNForwardOp<GPUDevice, T>); |
| |
| TF_CALL_half(REGISTER_GPU); |
| TF_CALL_float(REGISTER_GPU); |
| TF_CALL_double(REGISTER_GPU); |
| #undef REGISTER_GPU |
| |
| template <typename T> |
| class CudnnRNNForwardOpV2<GPUDevice, T> |
| : public CudnnRNNForwardOp<GPUDevice, T> { |
| private: |
| using CudnnRNNForwardOp<GPUDevice, T>::is_training; |
| using CudnnRNNKernelCommon::CreateRnnDescriptor; |
| using CudnnRNNKernelCommon::dropout; |
| using CudnnRNNKernelCommon::HasInputC; |
| using CudnnRNNKernelCommon::model_types; |
| |
| public: |
| explicit CudnnRNNForwardOpV2(OpKernelConstruction* context) |
| : CudnnRNNForwardOp<GPUDevice, T>(context) {} |
| |
| void Compute(OpKernelContext* context) override { |
| AlgorithmConfig best_algo_config; |
| CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm( |
| context, &best_algo_config, /*var_seq_lengths=*/false, |
| /*time_major=*/true, /*num_proj=*/0); |
| if (!context->status().ok()) { |
| return; |
| } |
| |
| Tensor* output_host_reserved = nullptr; |
| // output_host_reserved stores opaque info used for backprop when running |
| // in training mode. At present it is a serialization of the best |
| // AlgorithmDesc picked during the RNN forward-pass autotune: |
| //   int8 algorithm_id |
| //   int8 use_tensor_op |
| // If autotune is not enabled, algorithm_id is |
| // stream_executor::dnn::kDefaultAlgorithm and use_tensor_op is false. In |
| // inference mode, output_host_reserved is currently left unpopulated. A |
| // decode sketch follows this method. |
| if (is_training()) { |
| OP_REQUIRES_OK(context, context->allocate_output(4, TensorShape({2}), |
| &output_host_reserved)); |
| auto output_host_reserved_int8 = output_host_reserved->vec<int8>(); |
| output_host_reserved_int8(0) = best_algo_config.algorithm()->algo_id(); |
| output_host_reserved_int8(1) = |
| best_algo_config.algorithm()->tensor_ops_enabled(); |
| } else { |
| OP_REQUIRES_OK(context, |
| context->allocate_output(4, {}, &output_host_reserved)); |
| } |
| } |
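| |
| // A minimal decode sketch for the host_reserved layout documented in |
| // Compute() above; it mirrors what CudnnRNNBackwardOpV2::GetAlgorithm() |
| // does later in this file: |
| // |
| //   auto v = host_reserved->vec<int8>(); |
| //   AlgorithmConfig config; |
| //   config.set_algorithm(AlgorithmDesc(v(0), /*use_tensor_ops=*/v(1))); |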
| |
| protected: |
| Status MaybeAutoTune(OpKernelContext* context, |
| const CudnnRnnModelShapes& model_shapes, |
| const RnnInputMode& input_mode, const Tensor* input, |
| const Tensor* input_h, const Tensor* input_c, |
| const Tensor* params, Tensor* output, Tensor* output_h, |
| Tensor* output_c, |
| AlgorithmConfig* algo_config) override { |
| CHECK_NE(algo_config, nullptr); |
| if (!CudnnRnnUseAutotune() || this->is_debug_mode_) { |
| *algo_config = AlgorithmConfig(); |
| return Status::OK(); |
| } |
| |
| std::vector<AlgorithmDesc> algorithms; |
| auto* stream = context->op_device_context()->stream(); |
| CHECK(stream->parent()->GetRnnAlgorithms(&algorithms)); |
| if (algorithms.empty()) { |
| LOG(WARNING) << "No Rnn algorithm found"; |
| return Status::OK(); |
| } |
| |
| const auto& modeltypes = model_types(); |
| CudnnRnnParameters rnn_params( |
| model_shapes.num_layers, model_shapes.input_size, |
| model_shapes.num_units, model_shapes.max_seq_length, |
| model_shapes.batch_size, model_shapes.dir_count, |
| /*has_dropout=*/std::abs(dropout()) > 1e-8, is_training(), |
| modeltypes.rnn_mode, modeltypes.rnn_input_mode, input->dtype()); |
| |
| if (AutoTuneRnnConfigMap::GetInstance()->Find(rnn_params, algo_config)) { |
| VLOG(1) << "Using existing best Cudnn RNN algorithm " |
| << "(algo, tensor_op_enabled) = (" |
| << algo_config->algorithm()->algo_id() << ", " |
| << algo_config->algorithm()->tensor_ops_enabled() << ")."; |
| return Status::OK(); |
| } |
| |
| // Create temp tensors when profiling backprop pass. |
| auto data_type = input->dtype(); |
| Tensor output_backprop; |
| Tensor output_h_backprop; |
| Tensor output_c_backprop; |
| Tensor input_backprop; |
| Tensor input_h_backprop; |
| Tensor input_c_backprop; |
| Tensor params_backprop; |
| if (is_training()) { |
| TF_RETURN_IF_ERROR(context->allocate_temp( |
| data_type, model_shapes.output_shape, &output_backprop)); |
| TF_RETURN_IF_ERROR(context->allocate_temp( |
| data_type, model_shapes.hidden_state_shape, &output_h_backprop)); |
| |
| TF_RETURN_IF_ERROR( |
| context->allocate_temp(data_type, params->shape(), ¶ms_backprop)); |
| TF_RETURN_IF_ERROR(context->allocate_temp( |
| data_type, model_shapes.input_shape, &input_backprop)); |
| TF_RETURN_IF_ERROR(context->allocate_temp( |
| data_type, model_shapes.hidden_state_shape, &input_h_backprop)); |
| if (HasInputC()) { |
| TF_RETURN_IF_ERROR(context->allocate_temp( |
| data_type, model_shapes.hidden_state_shape, &output_c_backprop)); |
| TF_RETURN_IF_ERROR(context->allocate_temp( |
| data_type, model_shapes.hidden_state_shape, &input_c_backprop)); |
| } |
| } |
| ProfileResult best_result; |
| for (auto& algo : algorithms) { |
| VLOG(1) << "Profile Cudnn RNN algorithm (algo, tensor_op_enabled) = (" |
| << algo.algo_id() << ", " << algo.tensor_ops_enabled() << ")."; |
| Status status; |
| ProfileResult final_profile_result; |
| |
| ProfileResult fwd_profile_result; |
| ProfileResult bak_profile_result; |
| |
| // RnnDescriptor is algorithm-dependent, thus not reusable. |
| std::unique_ptr<RnnDescriptor> rnn_desc; |
| // Use a temp scratch allocator for the random number generator. |
| CudnnRnnAllocatorInTemp<uint8> dropout_state_allocator(context); |
| if (!this->template CreateRnnDescriptor<T>( |
| context, model_shapes, input_mode, AlgorithmConfig(algo), |
| &dropout_state_allocator, &rnn_desc, |
| /*use_padded_io=*/false) |
| .ok()) { |
| continue; |
| } |
| |
| // Again, use temp scratch allocators during profiling. |
| CudnnRnnAllocatorInTemp<T> reserve_space_allocator(context); |
| CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context); |
| status = DoForward<T>( |
| context, *rnn_desc, model_types(), model_shapes, input, input_h, |
| input_c, params, is_training(), output, output_h, output_c, |
| /*sequence_lengths=*/nullptr, /*time_major=*/true, |
| &reserve_space_allocator, &workspace_allocator, &fwd_profile_result); |
| if (!status.ok()) { |
| continue; |
| } |
| |
| if (is_training()) { |
| // Get reserve space from the forward pass. |
| Tensor reserve_space = reserve_space_allocator.get_allocated_tensor(0); |
| status = DoBackward<T>( |
| context, *rnn_desc, model_types(), model_shapes, input, input_h, |
| input_c, params, output, output_h, output_c, &output_backprop, |
| &output_h_backprop, &output_c_backprop, &reserve_space, |
| &input_backprop, &input_h_backprop, &input_c_backprop, |
| &params_backprop, /*sequence_lengths=*/nullptr, /*time_major=*/true, |
| &workspace_allocator, &bak_profile_result); |
| if (!status.ok()) { |
| continue; |
| } |
| final_profile_result.set_elapsed_time_in_ms( |
| fwd_profile_result.elapsed_time_in_ms() + |
| bak_profile_result.elapsed_time_in_ms()); |
| } else { |
| final_profile_result = fwd_profile_result; |
| } |
| |
| auto total_time = final_profile_result.elapsed_time_in_ms(); |
| VLOG(1) << "Cudnn RNN algorithm (algo, tensor_op_enabled) = (" |
| << algo.algo_id() << ", " << algo.tensor_ops_enabled() << ")" |
| << " run time: " << total_time << " ms."; |
| if (total_time < best_result.elapsed_time_in_ms()) { |
| best_result.set_elapsed_time_in_ms(total_time); |
| best_result.set_algorithm(algo); |
| } |
| } |
| |
| if (!best_result.is_valid()) { |
| return Status(error::Code::INTERNAL, "No algorithm worked!"); |
| } |
| algo_config->set_algorithm(best_result.algorithm()); |
| VLOG(1) << "Best Cudnn RNN algorithm (algo, tensor_op_enabled) = (" |
| << best_result.algorithm().algo_id() << ", " |
| << best_result.algorithm().tensor_ops_enabled() << ")."; |
| AutoTuneRnnConfigMap::GetInstance()->Insert(rnn_params, *algo_config); |
| return Status::OK(); |
| } |
| }; |
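| |
| // The autotune result above is cached in AutoTuneRnnConfigMap keyed by |
| // CudnnRnnParameters (layer/shape/dtype/dropout/training configuration), so |
| // profiling only runs once per distinct RNN configuration in a process; |
| // subsequent forward calls with the same key reuse the stored |
| // AlgorithmConfig. |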
| |
| #define REGISTER_GPU(T) \ |
| REGISTER_KERNEL_BUILDER(Name("CudnnRNNV2") \ |
| .Device(DEVICE_GPU) \ |
| .HostMemory("host_reserved") \ |
| .TypeConstraint<T>("T"), \ |
| CudnnRNNForwardOpV2<GPUDevice, T>); |
| |
| TF_CALL_half(REGISTER_GPU); |
| TF_CALL_float(REGISTER_GPU); |
| TF_CALL_double(REGISTER_GPU); |
| #undef REGISTER_GPU |
| |
| template <typename T> |
| class CudnnRNNForwardOpV3<GPUDevice, T> |
| : public CudnnRNNForwardOp<GPUDevice, T> { |
| private: |
| using CudnnRNNForwardOp<GPUDevice, T>::is_training; |
| using CudnnRNNKernelCommon::CreateRnnDescriptor; |
| using CudnnRNNKernelCommon::dropout; |
| using CudnnRNNKernelCommon::HasInputC; |
| using CudnnRNNKernelCommon::model_types; |
| bool time_major_; |
| |
| protected: |
| bool time_major() { return time_major_; } |
| |
| public: |
| explicit CudnnRNNForwardOpV3(OpKernelConstruction* context) |
| : CudnnRNNForwardOp<GPUDevice, T>(context) { |
| OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_)); |
| if (context->HasAttr("num_proj")) { |
| OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_)); |
| } else { |
| num_proj_ = 0; |
| } |
| } |
| |
| void Compute(OpKernelContext* context) override { |
| AlgorithmConfig best_algo_config; |
| CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm( |
| context, &best_algo_config, /*var_seq_lengths=*/true, |
| /*time_major=*/time_major(), num_proj_); |
| if (!context->status().ok()) { |
| return; |
| } |
| |
| Tensor* output_host_reserved = nullptr; |
| // TODO: The V3 op currently only uses the default (standard) algorithm to |
| // process batches with variable sequence lengths, and the inputs should be |
| // padded. Autotune is not supported yet. |
| OP_REQUIRES_OK(context, |
| context->allocate_output(4, {}, &output_host_reserved)); |
| } |
| |
| private: |
| int num_proj_; |
| }; |
| |
| #define REGISTER_GPU(T) \ |
| REGISTER_KERNEL_BUILDER(Name("CudnnRNNV3") \ |
| .Device(DEVICE_GPU) \ |
| .HostMemory("sequence_lengths") \ |
| .HostMemory("host_reserved") \ |
| .TypeConstraint<T>("T"), \ |
| CudnnRNNForwardOpV3<GPUDevice, T>); |
| |
| TF_CALL_half(REGISTER_GPU); |
| TF_CALL_float(REGISTER_GPU); |
| TF_CALL_double(REGISTER_GPU); |
| #undef REGISTER_GPU |
| |
| // Run the backward operation of the RNN model. |
| template <typename T> |
| class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon { |
| public: |
| explicit CudnnRNNBackwardOp(OpKernelConstruction* context) |
| : CudnnRNNKernelCommon(context) {} |
| |
| void Compute(OpKernelContext* context) override { |
| ComputeImpl(context, /*var_seq_lengths=*/false, /*time_major=*/true, |
| /*num_proj=*/0); |
| } |
| |
| protected: |
| virtual void ComputeImpl(OpKernelContext* context, bool var_seq_lengths, |
| bool time_major, int num_proj) { |
| const Tensor* input = nullptr; |
| const Tensor* input_h = nullptr; |
| const Tensor* input_c = nullptr; |
| const Tensor* params = nullptr; |
| const Tensor* sequence_lengths = nullptr; |
| CudnnRnnModelShapes model_shapes; |
| bool use_padded_io = false; |
| if (var_seq_lengths) { |
| OP_REQUIRES_OK(context, ExtractForwardInput( |
| context, model_types(), time_major, &input, |
| &input_h, &input_c, ¶ms, |
| &sequence_lengths, num_proj, &model_shapes)); |
| use_padded_io = |
| ShouldUsePaddedIO(sequence_lengths, model_shapes, time_major); |
| } else { |
| OP_REQUIRES_OK(context, |
| ExtractForwardInput(context, model_types(), time_major, |
| &input, &input_h, &input_c, ¶ms, |
| num_proj, &model_shapes)); |
| } |
| RnnInputMode input_mode; |
| OP_REQUIRES_OK(context, |
| ToRNNInputMode(rnn_input_mode(), model_shapes.num_units, |
| model_shapes.input_size, &input_mode)); |
| |
| const Tensor* output = nullptr; |
| const Tensor* output_h = nullptr; |
| const Tensor* output_c = nullptr; |
| const Tensor* output_backprop = nullptr; |
| const Tensor* output_h_backprop = nullptr; |
| const Tensor* output_c_backprop = nullptr; |
| const Tensor* reserve_space = nullptr; |
| OP_REQUIRES_OK(context, |
| ExtractBackwardInputs(context, model_shapes, model_types(), |
| &output, &output_h, &output_c, |
| &output_backprop, &output_h_backprop, |
| &output_c_backprop, &reserve_space)); |
| |
| Tensor* input_backprop = nullptr; |
| Tensor* input_h_backprop = nullptr; |
| Tensor* input_c_backprop = nullptr; |
| Tensor* params_backprop = nullptr; |
| OP_REQUIRES_OK(context, |
| AllocateOutputs(context, model_shapes, params->shape(), |
| &input_backprop, &input_h_backprop, |
| &input_c_backprop, ¶ms_backprop)); |
| |
| // Creates a memory callback for the workspace. The memory only needs to |
| // live until the end of this kernel call. |
| CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context); |
| AlgorithmConfig algo_config; |
| OP_REQUIRES_OK(context, GetAlgorithm(context, &algo_config)); |
| Status launch_status; |
| { |
| mutex_lock l(mu_); |
| RnnDescriptor* rnn_desc_ptr = nullptr; |
| OP_REQUIRES_OK( |
| context, GetCachedRnnDescriptor<T>(context, model_shapes, input_mode, |
| algo_config, &rnn_state_cache_, |
| &rnn_desc_ptr, use_padded_io)); |
| launch_status = DoBackward<T>( |
| context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h, |
| input_c, params, output, output_h, output_c, output_backprop, |
| output_h_backprop, output_c_backprop, reserve_space, input_backprop, |
| input_h_backprop, input_c_backprop, params_backprop, sequence_lengths, |
| time_major, &workspace_allocator, |
| /*output_profile_result=*/nullptr); |
| } |
| OP_REQUIRES_OK(context, launch_status); |
| } |
| |
| protected: |
| virtual Status GetAlgorithm(OpKernelContext* context, |
| AlgorithmConfig* algo_config) { |
| CHECK_NE(algo_config, nullptr); |
| *algo_config = AlgorithmConfig(); |
| return Status::OK(); |
| } |
| |
| private: |
| mutex mu_; |
| RnnStateCache rnn_state_cache_ GUARDED_BY(mu_); |
| |
| Status ExtractBackwardInputs( |
| OpKernelContext* context, const CudnnRnnModelShapes& model_shapes, |
| const CudnnModelTypes& model_types, const Tensor** output, |
| const Tensor** output_h, const Tensor** output_c, |
| const Tensor** output_backprop, const Tensor** output_h_backprop, |
| const Tensor** output_c_backprop, const Tensor** reserve_space) { |
| TF_RETURN_IF_ERROR(context->input("output", output)); |
| TF_RETURN_IF_ERROR(context->input("output_backprop", output_backprop)); |
| TF_RETURN_IF_ERROR(context->input("output_h", output_h)); |
| TF_RETURN_IF_ERROR(context->input("output_h_backprop", output_h_backprop)); |
| if (model_types.HasInputC()) { |
| TF_RETURN_IF_ERROR(context->input("output_c", output_c)); |
| TF_RETURN_IF_ERROR( |
| context->input("output_c_backprop", output_c_backprop)); |
| } |
| TF_RETURN_IF_ERROR(context->input("reserve_space", reserve_space)); |
| const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape; |
| const TensorShape& output_shape = model_shapes.output_shape; |
| const TensorShape& cell_state_shape = model_shapes.cell_state_shape; |
| |
| if (output_shape != (*output)->shape()) { |
| return errors::InvalidArgument( |
| "Invalid output shape: ", (*output)->shape().DebugString(), " ", |
| output_shape.DebugString()); |
| } |
| if (hidden_state_shape != (*output_h)->shape()) { |
| return errors::InvalidArgument( |
| "Invalid output_h shape: ", (*output_h)->shape().DebugString(), " ", |
| hidden_state_shape.DebugString()); |
| } |
| |
| if (output_shape != (*output_backprop)->shape()) { |
| return errors::InvalidArgument("Invalid output_backprop shape: ", |
| (*output_backprop)->shape().DebugString(), |
| " ", output_shape.DebugString()); |
| } |
| if (hidden_state_shape != (*output_h_backprop)->shape()) { |
| return errors::InvalidArgument( |
| "Invalid output_h_backprop shape: ", |
| (*output_h_backprop)->shape().DebugString(), " ", |
| hidden_state_shape.DebugString()); |
| } |
| |
| if (model_types.HasInputC()) { |
| if (cell_state_shape != (*output_c)->shape()) { |
| return errors::InvalidArgument( |
| "Invalid output_c shape: ", (*output_c)->shape().DebugString(), " ", |
| cell_state_shape.DebugString()); |
| } |
| if (cell_state_shape != (*output_c_backprop)->shape()) { |
| return errors::InvalidArgument( |
| "Invalid output_c_backprop shape: ", |
| (*output_c_backprop)->shape().DebugString(), " ", |
| cell_state_shape.DebugString()); |
| } |
| } |
| return Status::OK(); |
| } |
| |
| Status AllocateOutputs(OpKernelContext* context, |
| const CudnnRnnModelShapes& model_shapes, |
| const TensorShape& params_shape, |
| Tensor** input_backprop, Tensor** input_h_backprop, |
| Tensor** input_c_backprop, Tensor** params_backprop) { |
| const TensorShape& input_shape = model_shapes.input_shape; |
| const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape; |
| const TensorShape& cell_state_shape = model_shapes.cell_state_shape; |
| |
| TF_RETURN_IF_ERROR( |
| context->allocate_output(0, input_shape, input_backprop)); |
| TF_RETURN_IF_ERROR( |
| context->allocate_output(1, hidden_state_shape, input_h_backprop)); |
| if (HasInputC()) { |
| TF_RETURN_IF_ERROR( |
| context->allocate_output(2, cell_state_shape, input_c_backprop)); |
| } else { |
| // Only LSTM uses input_c and output_c, so for all other model types we |
| // only need to allocate a dummy output. |
| TF_RETURN_IF_ERROR(context->allocate_output(2, {}, input_c_backprop)); |
| } |
| TF_RETURN_IF_ERROR( |
| context->allocate_output(3, params_shape, params_backprop)); |
| return Status::OK(); |
| } |
| }; |
| |
| #define REGISTER_GPU(T) \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("CudnnRNNBackprop").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ |
| CudnnRNNBackwardOp<GPUDevice, T>); |
| |
| TF_CALL_half(REGISTER_GPU); |
| TF_CALL_float(REGISTER_GPU); |
| TF_CALL_double(REGISTER_GPU); |
| #undef REGISTER_GPU |
| |
| template <typename T> |
| class CudnnRNNBackwardOpV2<GPUDevice, T> |
| : public CudnnRNNBackwardOp<GPUDevice, T> { |
| public: |
| explicit CudnnRNNBackwardOpV2(OpKernelConstruction* context) |
| : CudnnRNNBackwardOp<GPUDevice, T>(context) {} |
| |
| protected: |
| Status GetAlgorithm(OpKernelContext* context, |
| AlgorithmConfig* algo_config) override { |
| CHECK_NE(algo_config, nullptr); |
| const Tensor* host_reserved = nullptr; |
| TF_RETURN_IF_ERROR(context->input("host_reserved", &host_reserved)); |
| |
| auto host_reserved_int8 = host_reserved->vec<int8>(); |
| const AlgorithmDesc algo_desc(host_reserved_int8(0), host_reserved_int8(1)); |
| algo_config->set_algorithm(algo_desc); |
| return Status::OK(); |
| } |
| }; |
| |
| #define REGISTER_GPU(T) \ |
| REGISTER_KERNEL_BUILDER(Name("CudnnRNNBackpropV2") \ |
| .Device(DEVICE_GPU) \ |
| .HostMemory("host_reserved") \ |
| .TypeConstraint<T>("T"), \ |
| CudnnRNNBackwardOpV2<GPUDevice, T>); |
| |
| TF_CALL_half(REGISTER_GPU); |
| TF_CALL_float(REGISTER_GPU); |
| TF_CALL_double(REGISTER_GPU); |
| #undef REGISTER_GPU |
| |
| template <typename T> |
| class CudnnRNNBackwardOpV3<GPUDevice, T> |
| : public CudnnRNNBackwardOp<GPUDevice, T> { |
| private: |
| bool time_major_; |
| |
| protected: |
| bool time_major() { return time_major_; } |
| |
| public: |
| explicit CudnnRNNBackwardOpV3(OpKernelConstruction* context) |
| : CudnnRNNBackwardOp<GPUDevice, T>(context) { |
| OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_)); |
| if (context->HasAttr("num_proj")) { |
| OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_)); |
| } else { |
| num_proj_ = 0; |
| } |
| } |
| |
| void Compute(OpKernelContext* context) override { |
| CudnnRNNBackwardOp<GPUDevice, T>::ComputeImpl( |
| context, /*var_seq_lengths=*/true, /*time_major=*/time_major(), |
| /*num_proj=*/num_proj_); |
| } |
| |
| private: |
| int num_proj_; |
| }; |
| |
| #define REGISTER_GPU(T) \ |
| REGISTER_KERNEL_BUILDER(Name("CudnnRNNBackpropV3") \ |
| .Device(DEVICE_GPU) \ |
| .HostMemory("sequence_lengths") \ |
| .HostMemory("host_reserved") \ |
| .TypeConstraint<T>("T"), \ |
| CudnnRNNBackwardOpV3<GPUDevice, T>); |
| |
| TF_CALL_half(REGISTER_GPU); |
| TF_CALL_float(REGISTER_GPU); |
| TF_CALL_double(REGISTER_GPU); |
| #undef REGISTER_GPU |
| |
| // TODO(zhengxq): Add the conversion of Cudnn RNN params to and from |
| // their canonical form. |
| |
| #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
| |
| } // namespace tensorflow |