| /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| // The Stream is used in conjunction with the StreamExecutor "parent" to |
| // perform actions with a linear stream of dependencies. Dependencies can also |
| // be created between Streams to do task management (i.e. limit which tasks |
| // can be performed concurrently and specify what task dependencies exist). |
| |
| #ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_STREAM_H_ |
| #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_STREAM_H_ |
| |
| #include <complex> |
| #include <functional> |
| #include <memory> |
| #include <type_traits> |
| |
| #include "absl/base/thread_annotations.h" |
| #include "absl/synchronization/mutex.h" |
| #include "absl/types/span.h" |
| #include "tensorflow/compiler/xla/stream_executor/blas.h" |
| #include "tensorflow/compiler/xla/stream_executor/device_memory.h" |
| #include "tensorflow/compiler/xla/stream_executor/dnn.h" |
| #include "tensorflow/compiler/xla/stream_executor/event.h" |
| #include "tensorflow/compiler/xla/stream_executor/fft.h" |
| #include "tensorflow/compiler/xla/stream_executor/kernel.h" |
| #include "tensorflow/compiler/xla/stream_executor/launch_dim.h" |
| #include "tensorflow/compiler/xla/stream_executor/platform/port.h" |
| #include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h" |
| #include "tensorflow/compiler/xla/stream_executor/temporary_memory_manager.h" |
| |
| namespace stream_executor { |
| |
| namespace host { |
| class HostBlas; |
| class HostFft; |
| class HostRng; |
| class HostTimer; |
| } // namespace host |
| |
| namespace ocl { |
| class CLBlas; |
| } // namespace ocl |
| |
| namespace internal { |
| class StreamInterface; |
| } // namespace internal |
| |
| class DeviceMemoryBase; |
| template <typename ElemT> |
| class DeviceMemory; |
| |
| class Timer; |
| |
| namespace dnn { |
| class BatchDescriptor; |
| class FilterDescriptor; |
| class ConvolutionDescriptor; |
| class ProfileResult; |
| class AlgorithmDesc; |
| } // namespace dnn |
| |
| class StreamExecutor; |
| class ScratchAllocator; |
| |
| namespace detail { |
| |
// Wraps `T` in a nested dependent name. Spelling a function parameter as
// `NonDeducedType<T>` keeps that parameter from being used to deduce `T`.
// This is identical to std::type_identity in C++20.
template <typename T>
struct NonDeduced {
  using type = T;
};

// Shorthand for `typename NonDeduced<T>::type`.
template <typename T>
using NonDeducedType = typename NonDeduced<T>::type;
| |
// Returns true if `T` is exactly the same type as at least one of
// `Candidates...`. With an empty candidate list the result is false.
template <typename T, typename... Candidates>
constexpr bool is_any_of() {
  // Unary fold over `||`; an empty pack folds to `false`.
  return (std::is_same_v<T, Candidates> || ...);
}
| |
| } // namespace detail |
| |
| // Convert a type to the corresponding QuantizedActivationMode. |
| template <typename ElementType> |
| struct Quantization; |
| |
| // Represents a stream of dependent computations on a GPU device. |
| // |
| // The operations within a stream execute linearly and asynchronously until |
| // BlockHostUntilDone() is invoked, which synchronously joins host code with |
| // the execution of the stream. |
| // |
| // If any given operation fails when entraining work for the stream, ok() will |
| // indicate that an error has occurred. After initialization, once a stream is |
| // !ok(), it will never be ok(). |
| // |
| // Thread-safe post-initialization. |
| class Stream { |
| public: |
| // Instantiate a stream tied to parent as a platform executor. Work |
| // entrained onto this stream will be launched/managed on that |
| // StreamExecutor's platform. |
| explicit Stream(StreamExecutor *parent); |
| |
// Deallocates any stream resources that the parent StreamExecutor has
// bestowed upon this object.
| ~Stream(); |
| |
| // Returns whether any errors have occurred while entraining work for this |
| // stream. |
| bool ok() const { return !InErrorState(); } |
| |
| // Retrieves execution status back into the stream from the underlying |
| // implementation without blocking the stream. |
| // |
| // Normally, Stream::BlockHostUntilDone is used to get execution status. |
// However, some devices use out-of-band mechanisms to ensure their streams
| // have finished on-device work, without needing to block the streams. (These |
| // devices should also override AllowsSyncOnCompletion to return false.) For |
| // these devices, this method can be used after work is finished to retrieve |
| // execution status. |
| port::Status RefreshStatus() TF_LOCKS_EXCLUDED(mu_); |
| |
| // Initialize the stream. This must be performed before entraining any other |
| // operations. |
| Stream &Init() TF_LOCKS_EXCLUDED(mu_); |
| |
| // Initializes timer t via the StreamExecutor. |
| Stream &InitTimer(Timer *t); |
| |
| // Convenience wrapper around Init() and InitTimer(). |
| Stream &InitWithTimer(Timer *t); |
| |
| // Get or create a sub-stream from this stream. If there is any sub-stream in |
| // the pool that can be reused then just return this sub-stream. Otherwise |
| // create a new sub-stream. |
| // |
| // TODO(b/112196569): The semantics of failed sub-streams is error-prone. |
| Stream *GetOrCreateSubStream() TF_LOCKS_EXCLUDED(mu_); |
| |
| // Return the sub-stream back to the host stream so that it can be reused |
| // later. Sub-streams that are !ok() will not be reused. |
| // |
| // TODO(b/112196569): The semantics of failed sub-streams is error-prone. |
| void ReturnSubStream(Stream *sub_stream) TF_LOCKS_EXCLUDED(mu_); |
| |
| // Allocate temporary memories. The stream will deallocate them when blocked |
| // or destroyed. |
| template <typename T> |
| port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>> |
| AllocateTemporaryArray(uint64_t element_count); |
| |
| // Entrains onto the stream of operations: a kernel launch with the given |
| // (variadic) parameters for the invocation. These arguments can be things |
| // like DeviceMemory or primitive types such as int. What arguments you may |
| // pass to a given kernel are noted as the template parameters to the |
| // TypedKernel type that the machocc compiler generates. |
| // |
| // Template parameters: |
| // Params... The type list of formal parameters that the typed kernel |
| // expects, which is matched against Args... |
| // Args... The deduced type list for passed actual arguments |
| // |
| // Implementation: A compile-time compatibility check is performed that has |
| // some leniency versus an exact parameter pack match -- for example, |
| // `const DeviceMemory<T>` is considered "pack compatible" with a |
| // `const DeviceMemory<T>&` formal parameter; in part, because we don't have |
| // perfect forwarding support without rvalue references. It also attempts to |
| // spit out helpful static_assert error traces with information as to the |
| // argument number and types that were mismatched. |
| template <typename... Params, typename... Args> |
| port::Status ThenLaunch(ThreadDim thread_dims, BlockDim block_dims, |
| const TypedKernel<Params...> &kernel, Args... args); |
| |
| // Record a "start" event for the interval timer at this point in the |
| // stream's execution (relative to the previously and subsequently enqueued |
| // items in the stream's execution). Streams may be started/stopped multiple |
| // times. |
| Stream &ThenStartTimer(Timer *t); |
| |
| // Record a "stop" event for the interval timer at this point in the |
| // stream's execution. See also Stream::ThenStartTimer. |
| Stream &ThenStopTimer(Timer *t); |
| |
| // TODO(leary) If work is added to the stream that is being depended upon, |
| // then what? Have to describe what happens. |
| template <typename... Params> |
| Stream &ThenWaitFor(Stream *other, Params... more_streams) { |
| return ThenWaitFor(more_streams...).ThenWaitFor(other); |
| } |
| |
| // Create a dependency for this stream's next work on the other stream |
| // completing. Does not take ownership of other, and other must not be |
| // null. |
| // |
| // Checks that a stream does not wait for itself, and it is up to the |
| // user to guarantee that a stream does not come to wait on itself in a |
| // cyclic manner; in that case, behavior is undefined. |
| // |
| // N.B. Base recursion case for the variadic ThenWaitFor. |
| Stream &ThenWaitFor(Stream *other); |
| |
| // Waits for all streams values in others. |
| // Checks that there is no shallow circular wait (i.e. that "this" is not in |
| // others) |
| template <typename P> |
| Stream &ThenWaitFor(P others) { |
| for (auto &stream : *others) { |
| CHECK_NE(stream.get(), this); |
| ThenWaitFor(stream.get()); |
| } |
| return *this; |
| } |
| |
| // Waits for an event object to be set. |
| // Note that ThenRecordEvent must have been called on the event before |
| // you call this function; otherwise the event will be considered complete |
| // and this wait will do nothing. |
| Stream &ThenWaitFor(Event *event); |
| |
| // Inserts the specified event into the end of this stream. Once the stream |
| // has processed all events prior to the insertion point, the event will be |
| // marked as completed. |
| // The stream does not take ownership of event - meaning that event's lifetime |
| // must extend past the point at which it is marked complete! |
| Stream &ThenRecordEvent(Event *event); |
| |
| //////////////// |
| // DNN support |
| // |
| // See DnnSupport::* for comments on the following methods. |
| |
| Stream &ThenBatchNormalizationForward( |
| const DeviceMemory<float> &x, const DeviceMemory<float> &scale, |
| const DeviceMemory<float> &offset, |
| const DeviceMemory<float> &estimated_mean, |
| const DeviceMemory<float> &estimated_variance, |
| const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc, |
| const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, |
| const double exponential_average_factor, |
| dnn::ActivationMode activation_mode, DeviceMemory<float> *y, |
| DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var, |
| DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var, |
| bool is_training, ScratchAllocator *reserve_space_allocator, |
| ScratchAllocator *workspace_allocator); |
| |
| Stream &ThenBatchNormalizationBackward( |
| const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x, |
| const DeviceMemory<float> &scale, const DeviceMemory<float> &offset, |
| const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var, |
| const DeviceMemory<float> &y, const dnn::BatchDescriptor &x_desc, |
| const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, |
| dnn::ActivationMode activation_mode, DeviceMemory<float> *x_backprop, |
| DeviceMemory<float> *scale_backprop, DeviceMemory<float> *offset_backprop, |
| DeviceMemory<float> *side_input_backprop, |
| DeviceMemory<uint8> *reserve_space_data, |
| ScratchAllocator *workspace_allocator); |
| |
| Stream &ThenBatchNormalizationForward( |
| const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale, |
| const DeviceMemory<float> &offset, |
| const DeviceMemory<float> &estimated_mean, |
| const DeviceMemory<float> &estimated_variance, |
| const DeviceMemory<Eigen::half> &side_input, |
| const dnn::BatchDescriptor &x_desc, |
| const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, |
| const double exponential_average_factor, |
| dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half> *y, |
| DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var, |
| DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var, |
| bool is_training, ScratchAllocator *reserve_space_allocator, |
| ScratchAllocator *workspace_allocator); |
| |
| Stream &ThenBatchNormalizationBackward( |
| const DeviceMemory<Eigen::half> &y_backprop, |
| const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale, |
| const DeviceMemory<float> &offset, const DeviceMemory<float> &mean, |
| const DeviceMemory<float> &inv_var, const DeviceMemory<Eigen::half> &y, |
| const dnn::BatchDescriptor &x_desc, |
| const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, |
| dnn::ActivationMode activation_mode, |
| DeviceMemory<Eigen::half> *x_backprop, |
| DeviceMemory<float> *scale_backprop, DeviceMemory<float> *offset_backprop, |
| DeviceMemory<Eigen::half> *side_input_backprop, |
| DeviceMemory<uint8> *reserve_space_data, |
| ScratchAllocator *workspace_allocator); |
| |
| Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor, |
| const DeviceMemory<float> &input_data, |
| const dnn::FilterDescriptor &filter_descriptor, |
| const DeviceMemory<float> &filter_data, |
| const dnn::ConvolutionDescriptor &convolution_descriptor, |
| const dnn::BatchDescriptor &output_descriptor, |
| DeviceMemory<float> *output); |
| |
| Stream &ThenConvolveQuantized( |
| const dnn::BatchDescriptor &input_descriptor, |
| const DeviceMemory<float> &input_data, |
| const dnn::FilterDescriptor &filter_descriptor, |
| const DeviceMemory<int8> &filter_coefficients, |
| const DeviceMemory<float> &coefficient_scales, |
| const dnn::ConvolutionDescriptor &convolution_descriptor, |
| const dnn::BatchDescriptor &output_descriptor, |
| DeviceMemory<float> *output_data); |
| |
| Stream &ThenConvolveQuantized( |
| const dnn::BatchDescriptor &input_descriptor, |
| const DeviceMemory<float> &input_data, |
| const dnn::FilterDescriptor &filter_descriptor, |
| const DeviceMemory<int16> &filter_coefficients, |
| const DeviceMemory<float> &coefficient_scales, |
| const dnn::ConvolutionDescriptor &convolution_descriptor, |
| const dnn::BatchDescriptor &output_descriptor, |
| DeviceMemory<float> *output_data); |
| |
| template <typename InputType, typename OutputType> |
| port::Status ConvolveWithAlgorithm( |
| dnn::ConvolutionKind kind, const dnn::BatchDescriptor &input_descriptor, |
| DeviceMemory<InputType> input_data, |
| const dnn::FilterDescriptor &filter_descriptor, |
| DeviceMemory<InputType> filter_data, |
| const dnn::BatchDescriptor &output_descriptor, |
| DeviceMemory<OutputType> output_data, |
| const dnn::ConvolutionDescriptor &convolution_descriptor, |
| ScratchAllocator *scratch_allocator, |
| const dnn::AlgorithmConfig &algorithm_config, |
| dnn::ProfileResult *output_profile_result) { |
| DeviceMemory<uint8> scratch_memory; |
| dnn::AlgorithmDesc algorithm_desc; |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| TF_RETURN_IF_ERROR(dnn->PrepareForConvolution( |
| kind, this, input_descriptor, input_data, filter_descriptor, |
| filter_data, output_descriptor, output_data, convolution_descriptor, |
| algorithm_config, scratch_allocator, &algorithm_desc, |
| &scratch_memory)); |
| return dnn->DoConvolve(kind, dnn::ToDataType<InputType>::value, |
| dnn::ToDataType<OutputType>::value, this, |
| input_descriptor, input_data, filter_descriptor, |
| filter_data, output_descriptor, output_data, |
| convolution_descriptor, algorithm_desc, |
| scratch_memory, output_profile_result); |
| } |
| return port::UnimplementedError("DNN library is not found."); |
| } |
| |
| template <typename InputT, typename ScaleT, typename SideInputT, |
| typename BiasT, typename OutputT> |
| port::Status FusedConvolveWithAlgorithm( |
| const dnn::BatchDescriptor &conv_input_descriptor, |
| const DeviceMemory<InputT> &conv_input_data, ScaleT conv_input_scale, |
| const dnn::FilterDescriptor &filter_descriptor, |
| const DeviceMemory<InputT> &filter_data, |
| const dnn::ConvolutionDescriptor &convolution_descriptor, |
| const DeviceMemory<SideInputT> &side_input_data, ScaleT side_input_scale, |
| const dnn::BatchDescriptor &bias_descriptor, |
| const DeviceMemory<BiasT> &biases, dnn::ActivationMode activation_mode, |
| const dnn::BatchDescriptor &output_descriptor, |
| DeviceMemory<OutputT> *output, ScratchAllocator *scratch_allocator, |
| const dnn::AlgorithmConfig &algorithm_config, |
| dnn::ProfileResult *output_profile_result) { |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| return dnn->DoFusedConvolve( |
| this, dnn::ToDataType<InputT>::value, |
| dnn::ToDataType<SideInputT>::value, dnn::ToDataType<BiasT>::value, |
| dnn::ToDataType<OutputT>::value, conv_input_descriptor, |
| conv_input_data, conv_input_scale, filter_descriptor, filter_data, |
| convolution_descriptor, side_input_data, side_input_scale, |
| bias_descriptor, biases, activation_mode, output_descriptor, *output, |
| scratch_allocator, algorithm_config, output_profile_result); |
| } |
| return port::UnimplementedError("DNN library is not found."); |
| } |
| |
| port::StatusOr<std::unique_ptr<const dnn::ConvRunner>> ConvolveRunnerFromDesc( |
| const dnn::AlgorithmDesc &algorithm_desc, dnn::ConvolutionKind kind, |
| dnn::DataType element_type, dnn::DataType output_type, |
| const dnn::BatchDescriptor &input_descriptor, |
| const dnn::FilterDescriptor &filter_descriptor, |
| const dnn::BatchDescriptor &output_descriptor, |
| const dnn::ConvolutionDescriptor &convolution_descriptor) { |
| dnn::DnnSupport *dnn_support = parent_->AsDnn(); |
| if (!dnn_support) { |
| return port::UnimplementedError("DNN library is not found."); |
| } |
| return dnn_support->ConvolveRunnerFromDesc( |
| this, algorithm_desc, kind, element_type, output_type, input_descriptor, |
| filter_descriptor, output_descriptor, convolution_descriptor); |
| } |
| |
| port::StatusOr<std::unique_ptr<const dnn::FusedConvRunner>> |
| FusedConvolveRunnerFromDesc( |
| const dnn::AlgorithmDesc &algorithm_desc, dnn::ConvolutionKind kind, |
| dnn::DataType element_type, dnn::DataType bias_type, |
| dnn::DataType output_type, double conv_input_scale, |
| double side_input_scale, double leakyrelu_alpha, |
| const dnn::BatchDescriptor &input_descriptor, |
| const dnn::FilterDescriptor &filter_descriptor, |
| const dnn::BatchDescriptor &bias_descriptor, |
| const dnn::BatchDescriptor &output_descriptor, |
| const dnn::ConvolutionDescriptor &convolution_descriptor, |
| dnn::ActivationMode activation_mode) { |
| dnn::DnnSupport *dnn_support = parent_->AsDnn(); |
| if (!dnn_support) { |
| return port::UnimplementedError("DNN library is not found."); |
| } |
| return dnn_support->FusedConvolveRunnerFromDesc( |
| this, algorithm_desc, kind, element_type, bias_type, output_type, |
| conv_input_scale, side_input_scale, leakyrelu_alpha, input_descriptor, |
| filter_descriptor, bias_descriptor, output_descriptor, |
| convolution_descriptor, activation_mode); |
| } |
| |
| Stream &ThenSeparableConvolve( |
| const dnn::BatchDescriptor &input_descriptor, |
| const DeviceMemory<float> &input_data, |
| const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier, |
| const DeviceMemory<float> &first_weights, |
| const DeviceMemory<float> &second_weights, |
| const dnn::ConvolutionDescriptor &convolution_descriptor, |
| const dnn::BatchDescriptor &output_descriptor, |
| DeviceMemory<float> *output); |
| |
| Stream &ThenMatMul(const DeviceMemory<float> &input_data, |
| const DeviceMemory<float> &weights, |
| const dnn::BatchDescriptor &input_dimensions, |
| const dnn::BatchDescriptor &output_dimensions, |
| DeviceMemory<float> *output_data); |
| |
| Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data, |
| const DeviceMemory<int8> &weights, |
| const DeviceMemory<float> &weight_scales, |
| const dnn::BatchDescriptor &input_dimensions, |
| const dnn::BatchDescriptor &output_dimensions, |
| DeviceMemory<float> *output_data); |
| |
| Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data, |
| const DeviceMemory<int16> &weights, |
| const DeviceMemory<float> &weight_scales, |
| const dnn::BatchDescriptor &input_dimensions, |
| const dnn::BatchDescriptor &output_dimensions, |
| DeviceMemory<float> *output_data); |
| |
| Stream &ThenBiasAdd(const DeviceMemory<float> &input_data, |
| const DeviceMemory<float> &biases, |
| const dnn::BatchDescriptor &dimensions, |
| DeviceMemory<float> *output_data); |
| |
| template <typename ElementType> |
| port::Status ThenPoolForward( |
| const dnn::PoolingDescriptor &pooling_dimensions, |
| const dnn::BatchDescriptor &input_dimensions, |
| const DeviceMemory<ElementType> &input_data, |
| const dnn::BatchDescriptor &output_dimensions, |
| DeviceMemory<ElementType> *output_data, |
| ScratchAllocator *workspace_allocator = nullptr) { |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| return dnn->DoPoolForward(dnn::ToDataType<ElementType>::value, this, |
| pooling_dimensions, input_dimensions, |
| input_data, output_dimensions, *output_data, |
| workspace_allocator); |
| } |
| return port::UnimplementedError("DNN library is not found."); |
| } |
| |
| template <typename ElementType> |
| port::Status ThenPoolBackward( |
| const dnn::PoolingDescriptor &pooling_dimensions, |
| const dnn::BatchDescriptor &input_dimensions, |
| const DeviceMemory<ElementType> &input_data, |
| const dnn::BatchDescriptor &output_dimensions, |
| const DeviceMemory<ElementType> &output_data, |
| const DeviceMemory<ElementType> &input_diff_data, |
| DeviceMemory<ElementType> *output_diff_data, |
| ScratchAllocator *workspace_allocator = nullptr) { |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| return dnn->DoPoolBackward( |
| dnn::ToDataType<ElementType>::value, this, pooling_dimensions, |
| input_dimensions, input_data, output_dimensions, output_data, |
| input_diff_data, *output_diff_data, workspace_allocator); |
| } |
| return port::UnimplementedError("DNN library is not found."); |
| } |
| |
| Stream &ThenNormalizeWithDimensions( |
| const dnn::NormalizeDescriptor &normalize_descriptor, |
| const dnn::BatchDescriptor &dimensions, |
| const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data); |
| |
| Stream &ThenNormalizeBackwardWithDimensions( |
| const dnn::NormalizeDescriptor &normalize_descriptor, |
| const dnn::BatchDescriptor &dimensions, |
| const DeviceMemory<float> &raw_data, |
| const DeviceMemory<float> &normalized_data, |
| const DeviceMemory<float> &normalized_variable_gradient, |
| DeviceMemory<float> *raw_variable_gradient, |
| ScratchAllocator *workspace_allocator = nullptr); |
| |
| Stream &ThenActivate(dnn::ActivationMode activation_mode, |
| const dnn::BatchDescriptor &dimensions, |
| const DeviceMemory<float> &input_data, |
| DeviceMemory<float> *output_data); |
| |
| // Same as ThenActivate, but also takes an options argument that can be used |
| // for platform-specific option flags. |
| Stream &ThenActivateWithOptions(dnn::ActivationMode activation_mode, |
| const dnn::BatchDescriptor &dimensions, |
| const DeviceMemory<float> &input_data, |
| DeviceMemory<float> *output_data, |
| uint64_t options); |
| |
| Stream &ThenDepthConcatenate( |
| absl::Span<const dnn::BatchDescriptor> input_dimensions, |
| absl::Span<const DeviceMemory<float> *const> input_data, |
| DeviceMemory<float> *output_data); |
| |
| Stream &ThenSpaceConcatenate( |
| absl::Span<const dnn::BatchDescriptor> input_dimensions, |
| absl::Span<const DeviceMemory<float> *const> input_data, |
| DeviceMemory<float> *output_data, |
| dnn::SpaceConcatenateMode concat_direction); |
| |
| // Change the layout of the data by shrinking one dimension (or set of |
| // dimensions) and growing another dimension (or set of dimensions), while |
| // keeping the total number of data elements constant, and maintaining the |
| // current data ordering. |
| Stream &ThenReshape(const dnn::BatchDescriptor &input_dimensions, |
| const DeviceMemory<float> &input_data, |
| const dnn::BatchDescriptor &output_dimensions, |
| DeviceMemory<float> *output_data); |
| |
| // Depth to space takes an X by Y image with depth D*M² and changes it to an |
| // MX x MY image with depth D. Each input location (x,y) with depth D*M² in |
| // the input image is changed to an MxM contiguous area in the output image, |
| // with the values being laid out in raster order specified by |
| // DepthToSpaceLayout, and will have a new depth of D. |
| // See the DoDepthToSpace comment for more information. |
| Stream &ThenDepthToSpace(const dnn::BatchDescriptor &input_dimensions, |
| const DeviceMemory<float> &input_data, |
| const dnn::DepthToSpaceLayout &depth_to_space_layout, |
| const int sqrt_depth_reduction, |
| DeviceMemory<float> *output_data); |
| |
| // Space to depth is the inverse of depth to space. Space to depth takes each |
| // non-overlapping M by M patch (in the X and Y dimensions) with depth D of |
| // the input, and transforms it to a 1 by 1 patch with depth D*M². If the |
| // input has size (MX, MY, D), the output has size (X, Y, D*M²). The number of |
| // data elements is not changed. |
| Stream &ThenSpaceToDepth(const dnn::BatchDescriptor &input_dimensions, |
| const DeviceMemory<float> &input_data, |
| const dnn::DepthToSpaceLayout &space_to_depth_layout, |
| const int sqrt_depth_increase, |
| DeviceMemory<float> *output_data); |
| |
| Stream &ThenElementwiseOperate( |
| dnn::ElementwiseOperation operation, |
| absl::Span<const dnn::BatchDescriptor> input_dimensions, |
| absl::Span<const DeviceMemory<float> *const> input_data, |
| const dnn::BatchDescriptor &output_dimensions, |
| DeviceMemory<float> *output_data); |
| |
| Stream &ThenElementwiseOperateScaledQuantized( |
| dnn::ElementwiseOperation operation, |
| absl::Span<const int> input_multiplicands, int output_divisor, |
| absl::Span<const dnn::BatchDescriptor> input_dimensions, |
| absl::Span<const DeviceMemory<float> *const> input_data, |
| const dnn::BatchDescriptor &output_dimensions, |
| DeviceMemory<float> *output_data); |
| |
| Stream &ThenXYPad(const dnn::BatchDescriptor &dimensions, |
| const DeviceMemory<float> &input_data, int64_t left_pad, |
| int64_t right_pad, int64_t top_pad, int64_t bottom_pad, |
| DeviceMemory<float> *output_data); |
| |
| Stream &ThenXYSlice(const dnn::BatchDescriptor &dimensions, |
| const DeviceMemory<float> &input_data, int64_t left_trim, |
| int64_t right_trim, int64_t top_trim, int64_t bottom_trim, |
| DeviceMemory<float> *output_data); |
| |
| // Grows the input tensor by replicating the X and Y dimensions. The batch and |
| // depth/feature_map dimensions are unchanged. Currently, the input tensor is |
| // limited to X=1 and Y=1. |
| Stream &ThenXYBroadcast(const dnn::BatchDescriptor &dimensions, |
| const DeviceMemory<float> &input_data, |
| int64_t replicate_x, int64_t replicate_y, |
| DeviceMemory<float> *output_data); |
| |
| // See DnnSupport::DoMemcpyD2HQuantized. |
| Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src, |
| dnn::QuantizedActivationMode mode, |
| void *host_dst, uint64_t size); |
| |
| // Template version of ThenMemcpyD2HQuantized that takes a Span |
| // and uses the Quantization trait to call the generic version of |
| // ThenMemcpyD2HQuantized with the correct QuantizedActivationMode. |
| template <typename ElementType> |
| Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src, |
| absl::Span<ElementType> host_dst) { |
| return ThenMemcpyD2HQuantized( |
| gpu_unquantized_src, Quantization<ElementType>::kModeId, |
| host_dst.data(), host_dst.size() * sizeof(ElementType)); |
| } |
| |
| // See DnnSupport::DoMemcpyH2DQuantized. |
| Stream &ThenMemcpyH2DQuantized(const void *host_src, uint64_t size, |
| dnn::QuantizedActivationMode mode, |
| DeviceMemory<float> *gpu_unquantized_dst); |
| |
| // Template version of ThenMemcpyH2DQuantized that takes an array slice |
| // and uses the Quantization trait to call the generic version of |
| // ThenMemcpyH2DQuantized with the correct QuantizedActivationMode. |
| template <typename ElementType> |
| Stream &ThenMemcpyH2DQuantized(absl::Span<const ElementType> host_src, |
| DeviceMemory<float> *gpu_unquantized_dst) { |
| return ThenMemcpyH2DQuantized( |
| host_src.data(), host_src.size() * sizeof(ElementType), |
| Quantization<ElementType>::kModeId, gpu_unquantized_dst); |
| } |
| |
| // See DnnSupport::DoCopyHostBuffer2Device. |
| Stream &ThenCopyHostBuffer2Device(HostBuffer *buffer_src, |
| DeviceMemory<float> *gpu_unquantized_dst); |
| |
| // See DnnSupport::DoCopyDevice2HostBuffer. |
| Stream &ThenCopyDevice2HostBuffer( |
| const DeviceMemory<float> &gpu_unquantized_src, HostBuffer *buffer_dst); |
| |
| ///////////////// |
| // BLAS support |
| |
| // See BlasSupport::DoBlasAsum. |
| Stream &ThenBlasAsum(uint64_t elem_count, const DeviceMemory<float> &x, |
| int incx, DeviceMemory<float> *result); |
| Stream &ThenBlasAsum(uint64_t elem_count, const DeviceMemory<double> &x, |
| int incx, DeviceMemory<double> *result); |
| Stream &ThenBlasAsum(uint64_t elem_count, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| DeviceMemory<float> *result); |
| Stream &ThenBlasAsum(uint64_t elem_count, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| DeviceMemory<double> *result); |
| |
// See BlasSupport::DoBlasAxpy. Note that, even for the case where alpha is
// present in DeviceMemory, it must be an execution-time constant (i.e. a
// value that the stream does not change or populate during the course of
// execution). The value is effectively captured at stream-enqueue time.
| Stream &ThenBlasAxpy(uint64_t elem_count, float alpha, |
| const DeviceMemory<float> &x, int incx, |
| DeviceMemory<float> *y, int incy); |
| Stream &ThenBlasAxpy(uint64_t elem_count, double alpha, |
| const DeviceMemory<double> &x, int incx, |
| DeviceMemory<double> *y, int incy); |
| Stream &ThenBlasAxpy(uint64_t elem_count, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| DeviceMemory<std::complex<float>> *y, int incy); |
| Stream &ThenBlasAxpy(uint64_t elem_count, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| DeviceMemory<std::complex<double>> *y, int incy); |
| |
| // See BlasSupport::DoBlasCopy. |
| Stream &ThenBlasCopy(uint64_t elem_count, const DeviceMemory<float> &x, |
| int incx, DeviceMemory<float> *y, int incy); |
| Stream &ThenBlasCopy(uint64_t elem_count, const DeviceMemory<double> &x, |
| int incx, DeviceMemory<double> *y, int incy); |
| Stream &ThenBlasCopy(uint64_t elem_count, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| DeviceMemory<std::complex<float>> *y, int incy); |
| Stream &ThenBlasCopy(uint64_t elem_count, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| DeviceMemory<std::complex<double>> *y, int incy); |
| |
| // See BlasSupport::DoBlasDot. |
| Stream &ThenBlasDot(uint64_t elem_count, const DeviceMemory<float> &x, |
| int incx, const DeviceMemory<float> &y, int incy, |
| DeviceMemory<float> *result); |
| Stream &ThenBlasDot(uint64_t elem_count, const DeviceMemory<double> &x, |
| int incx, const DeviceMemory<double> &y, int incy, |
| DeviceMemory<double> *result); |
| |
| // See BlasSupport::DoBlasDotc. |
| Stream &ThenBlasDotc(uint64_t elem_count, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| const DeviceMemory<std::complex<float>> &y, int incy, |
| DeviceMemory<std::complex<float>> *result); |
| Stream &ThenBlasDotc(uint64_t elem_count, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| const DeviceMemory<std::complex<double>> &y, int incy, |
| DeviceMemory<std::complex<double>> *result); |
| |
| // See BlasSupport::DoBlasDotu. Complex-only overloads (std::complex<float>, std::complex<double>); presumably the unconjugated dot product - confirm in BlasSupport. |
| Stream &ThenBlasDotu(uint64_t elem_count, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| const DeviceMemory<std::complex<float>> &y, int incy, |
| DeviceMemory<std::complex<float>> *result); |
| Stream &ThenBlasDotu(uint64_t elem_count, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| const DeviceMemory<std::complex<double>> &y, int incy, |
| DeviceMemory<std::complex<double>> *result); |
| |
| // See BlasSupport::DoBlasNrm2. Note that the complex overloads write a real-typed result (DeviceMemory<float> / DeviceMemory<double>). |
| Stream &ThenBlasNrm2(uint64_t elem_count, const DeviceMemory<float> &x, |
| int incx, DeviceMemory<float> *result); |
| Stream &ThenBlasNrm2(uint64_t elem_count, const DeviceMemory<double> &x, |
| int incx, DeviceMemory<double> *result); |
| Stream &ThenBlasNrm2(uint64_t elem_count, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| DeviceMemory<float> *result); |
| Stream &ThenBlasNrm2(uint64_t elem_count, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| DeviceMemory<double> *result); |
| |
| // See BlasSupport::DoBlasRot. x and y are both mutable pointers; the complex overloads still take real-typed c and s (float / double). |
| Stream &ThenBlasRot(uint64_t elem_count, DeviceMemory<float> *x, int incx, |
| DeviceMemory<float> *y, int incy, float c, float s); |
| Stream &ThenBlasRot(uint64_t elem_count, DeviceMemory<double> *x, int incx, |
| DeviceMemory<double> *y, int incy, double c, double s); |
| Stream &ThenBlasRot(uint64_t elem_count, DeviceMemory<std::complex<float>> *x, |
| int incx, DeviceMemory<std::complex<float>> *y, int incy, |
| float c, float s); |
| Stream &ThenBlasRot(uint64_t elem_count, |
| DeviceMemory<std::complex<double>> *x, int incx, |
| DeviceMemory<std::complex<double>> *y, int incy, double c, |
| double s); |
| |
| // See BlasSupport::DoBlasRotg. All four arguments are mutable DeviceMemory pointers; in the complex overloads c is real-typed while a, b and s are complex. |
| Stream &ThenBlasRotg(DeviceMemory<float> *a, DeviceMemory<float> *b, |
| DeviceMemory<float> *c, DeviceMemory<float> *s); |
| Stream &ThenBlasRotg(DeviceMemory<double> *a, DeviceMemory<double> *b, |
| DeviceMemory<double> *c, DeviceMemory<double> *s); |
| Stream &ThenBlasRotg(DeviceMemory<std::complex<float>> *a, |
| DeviceMemory<std::complex<float>> *b, |
| DeviceMemory<float> *c, |
| DeviceMemory<std::complex<float>> *s); |
| Stream &ThenBlasRotg(DeviceMemory<std::complex<double>> *a, |
| DeviceMemory<std::complex<double>> *b, |
| DeviceMemory<double> *c, |
| DeviceMemory<std::complex<double>> *s); |
| |
| // See BlasSupport::DoBlasRotm. |
| // Real-valued overloads only (float, double). x and y are mutable |
| // DeviceMemory pointers; `param` is taken by const reference. Each overload |
| // returns a Stream&. |
| Stream &ThenBlasRotm(uint64_t elem_count, DeviceMemory<float> *x, int incx, |
| DeviceMemory<float> *y, int incy, |
| const DeviceMemory<float> &param); |
| Stream &ThenBlasRotm(uint64_t elem_count, DeviceMemory<double> *x, int incx, |
| DeviceMemory<double> *y, int incy, |
| const DeviceMemory<double> &param); |
| |
| // See BlasSupport::DoBlasRotmg. Real-valued overloads only (float, double); y1 is the only read-only argument, while d1, d2, x1 and param are mutable pointers. |
| Stream &ThenBlasRotmg(DeviceMemory<float> *d1, DeviceMemory<float> *d2, |
| DeviceMemory<float> *x1, const DeviceMemory<float> &y1, |
| DeviceMemory<float> *param); |
| Stream &ThenBlasRotmg(DeviceMemory<double> *d1, DeviceMemory<double> *d2, |
| DeviceMemory<double> *x1, |
| const DeviceMemory<double> &y1, |
| DeviceMemory<double> *param); |
| |
| // See BlasSupport::DoBlasScal. Six overloads: real alpha with real x, real alpha with complex x, and complex alpha with complex x, each in float and double precision. |
| Stream &ThenBlasScal(uint64_t elem_count, float alpha, DeviceMemory<float> *x, |
| int incx); |
| Stream &ThenBlasScal(uint64_t elem_count, double alpha, |
| DeviceMemory<double> *x, int incx); |
| Stream &ThenBlasScal(uint64_t elem_count, float alpha, |
| DeviceMemory<std::complex<float>> *x, int incx); |
| Stream &ThenBlasScal(uint64_t elem_count, double alpha, |
| DeviceMemory<std::complex<double>> *x, int incx); |
| Stream &ThenBlasScal(uint64_t elem_count, std::complex<float> alpha, |
| DeviceMemory<std::complex<float>> *x, int incx); |
| Stream &ThenBlasScal(uint64_t elem_count, std::complex<double> alpha, |
| DeviceMemory<std::complex<double>> *x, int incx); |
| |
| // See BlasSupport::DoBlasSwap. x and y are both mutable DeviceMemory pointers; overloaded for float, double, std::complex<float> and std::complex<double>. |
| Stream &ThenBlasSwap(uint64_t elem_count, DeviceMemory<float> *x, int incx, |
| DeviceMemory<float> *y, int incy); |
| Stream &ThenBlasSwap(uint64_t elem_count, DeviceMemory<double> *x, int incx, |
| DeviceMemory<double> *y, int incy); |
| Stream &ThenBlasSwap(uint64_t elem_count, |
| DeviceMemory<std::complex<float>> *x, int incx, |
| DeviceMemory<std::complex<float>> *y, int incy); |
| Stream &ThenBlasSwap(uint64_t elem_count, |
| DeviceMemory<std::complex<double>> *x, int incx, |
| DeviceMemory<std::complex<double>> *y, int incy); |
| |
| // See BlasSupport::DoBlasIamax. `result` is DeviceMemory<int> for every element type (float, double, std::complex<float>, std::complex<double>); x is read-only. |
| Stream &ThenBlasIamax(uint64_t elem_count, const DeviceMemory<float> &x, |
| int incx, DeviceMemory<int> *result); |
| Stream &ThenBlasIamax(uint64_t elem_count, const DeviceMemory<double> &x, |
| int incx, DeviceMemory<int> *result); |
| Stream &ThenBlasIamax(uint64_t elem_count, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| DeviceMemory<int> *result); |
| Stream &ThenBlasIamax(uint64_t elem_count, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| DeviceMemory<int> *result); |
| |
| // See BlasSupport::DoBlasIamin. `result` is DeviceMemory<int> for every element type (float, double, std::complex<float>, std::complex<double>); x is read-only. |
| Stream &ThenBlasIamin(uint64_t elem_count, const DeviceMemory<float> &x, |
| int incx, DeviceMemory<int> *result); |
| Stream &ThenBlasIamin(uint64_t elem_count, const DeviceMemory<double> &x, |
| int incx, DeviceMemory<int> *result); |
| Stream &ThenBlasIamin(uint64_t elem_count, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| DeviceMemory<int> *result); |
| Stream &ThenBlasIamin(uint64_t elem_count, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| DeviceMemory<int> *result); |
| |
| // See BlasSupport::DoBlasGbmv. alpha/beta match the element type; a and x are read-only, y is mutable. NOTE(review): sizes mix `uint64_t` and `uint64` spellings - presumably aliases, confirm before unifying. |
| Stream &ThenBlasGbmv(blas::Transpose trans, uint64_t m, uint64 n, uint64 kl, |
| uint64_t ku, float alpha, const DeviceMemory<float> &a, |
| int lda, const DeviceMemory<float> &x, int incx, |
| float beta, DeviceMemory<float> *y, int incy); |
| Stream &ThenBlasGbmv(blas::Transpose trans, uint64_t m, uint64 n, uint64 kl, |
| uint64_t ku, double alpha, const DeviceMemory<double> &a, |
| int lda, const DeviceMemory<double> &x, int incx, |
| double beta, DeviceMemory<double> *y, int incy); |
| Stream &ThenBlasGbmv(blas::Transpose trans, uint64_t m, uint64 n, uint64 kl, |
| uint64_t ku, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *y, int incy); |
| Stream &ThenBlasGbmv(blas::Transpose trans, uint64_t m, uint64 n, uint64 kl, |
| uint64_t ku, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *y, int incy); |
| |
| // See BlasSupport::DoBlasGemv. The ThenBlasGemvWithProfiling overloads that follow take the same arguments plus a trailing blas::ProfileResult *output_profile_result. |
| Stream &ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n, float alpha, |
| const DeviceMemory<float> &a, int lda, |
| const DeviceMemory<float> &x, int incx, float beta, |
| DeviceMemory<float> *y, int incy); |
| Stream &ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n, |
| double alpha, const DeviceMemory<double> &a, int lda, |
| const DeviceMemory<double> &x, int incx, double beta, |
| DeviceMemory<double> *y, int incy); |
| Stream &ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *y, int incy); |
| Stream &ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *y, int incy); |
| |
| Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64_t m, uint64 n, |
| float alpha, const DeviceMemory<float> &a, |
| int lda, const DeviceMemory<float> &x, |
| int incx, float beta, |
| DeviceMemory<float> *y, int incy, |
| blas::ProfileResult *output_profile_result); |
| Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64_t m, uint64 n, |
| double alpha, const DeviceMemory<double> &a, |
| int lda, const DeviceMemory<double> &x, |
| int incx, double beta, |
| DeviceMemory<double> *y, int incy, |
| blas::ProfileResult *output_profile_result); |
| Stream &ThenBlasGemvWithProfiling( |
| blas::Transpose trans, uint64_t m, uint64 n, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy, |
| blas::ProfileResult *output_profile_result); |
| Stream &ThenBlasGemvWithProfiling( |
| blas::Transpose trans, uint64_t m, uint64 n, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| std::complex<double> beta, DeviceMemory<std::complex<double>> *y, |
| int incy, blas::ProfileResult *output_profile_result); |
| |
| // See BlasSupport::DoBlasGer. Real-valued overloads only (float, double); x and y are read-only and `a` (with `lda`) is a mutable DeviceMemory pointer. |
| Stream &ThenBlasGer(uint64_t m, uint64 n, float alpha, |
| const DeviceMemory<float> &x, int incx, |
| const DeviceMemory<float> &y, int incy, |
| DeviceMemory<float> *a, int lda); |
| Stream &ThenBlasGer(uint64_t m, uint64 n, double alpha, |
| const DeviceMemory<double> &x, int incx, |
| const DeviceMemory<double> &y, int incy, |
| DeviceMemory<double> *a, int lda); |
| |
| // See BlasSupport::DoBlasGerc. Complex-only overloads (std::complex<float>, std::complex<double>); x and y are read-only and `a` (with `lda`) is a mutable pointer. |
| Stream &ThenBlasGerc(uint64_t m, uint64 n, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| const DeviceMemory<std::complex<float>> &y, int incy, |
| DeviceMemory<std::complex<float>> *a, int lda); |
| Stream &ThenBlasGerc(uint64_t m, uint64 n, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| const DeviceMemory<std::complex<double>> &y, int incy, |
| DeviceMemory<std::complex<double>> *a, int lda); |
| |
| // See BlasSupport::DoBlasGeru. Complex-only overloads (std::complex<float>, std::complex<double>); x and y are read-only and `a` (with `lda`) is a mutable pointer. |
| Stream &ThenBlasGeru(uint64_t m, uint64 n, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| const DeviceMemory<std::complex<float>> &y, int incy, |
| DeviceMemory<std::complex<float>> *a, int lda); |
| Stream &ThenBlasGeru(uint64_t m, uint64 n, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| const DeviceMemory<std::complex<double>> &y, int incy, |
| DeviceMemory<std::complex<double>> *a, int lda); |
| |
| // See BlasSupport::DoBlasHbmv. Complex-only overloads with complex alpha and beta; a and x are read-only and y is a mutable DeviceMemory pointer. |
| Stream &ThenBlasHbmv(blas::UpperLower uplo, uint64_t n, uint64 k, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *y, int incy); |
| Stream &ThenBlasHbmv(blas::UpperLower uplo, uint64_t n, uint64 k, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *y, int incy); |
| |
| // See BlasSupport::DoBlasHemv. Complex-only overloads with complex alpha and beta; a and x are read-only and y is a mutable DeviceMemory pointer. |
| Stream &ThenBlasHemv(blas::UpperLower uplo, uint64_t n, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *y, int incy); |
| Stream &ThenBlasHemv(blas::UpperLower uplo, uint64_t n, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *y, int incy); |
| |
| // See BlasSupport::DoBlasHer. Element types are complex but alpha is real (float / double); x is read-only and `a` (with `lda`) is a mutable pointer. |
| Stream &ThenBlasHer(blas::UpperLower uplo, uint64_t n, float alpha, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| DeviceMemory<std::complex<float>> *a, int lda); |
| Stream &ThenBlasHer(blas::UpperLower uplo, uint64_t n, double alpha, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| DeviceMemory<std::complex<double>> *a, int lda); |
| |
| // See BlasSupport::DoBlasHer2. Complex-only overloads with complex alpha; x and y are read-only and `a` (with `lda`) is a mutable DeviceMemory pointer. |
| Stream &ThenBlasHer2(blas::UpperLower uplo, uint64_t n, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| const DeviceMemory<std::complex<float>> &y, int incy, |
| DeviceMemory<std::complex<float>> *a, int lda); |
| Stream &ThenBlasHer2(blas::UpperLower uplo, uint64_t n, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| const DeviceMemory<std::complex<double>> &y, int incy, |
| DeviceMemory<std::complex<double>> *a, int lda); |
| |
| // See BlasSupport::DoBlasHpmv. Complex-only overloads; takes `ap` with no leading-dimension argument (contrast ThenBlasHemv's `a`, `lda`); y is a mutable pointer. |
| Stream &ThenBlasHpmv(blas::UpperLower uplo, uint64_t n, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &ap, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *y, int incy); |
| Stream &ThenBlasHpmv(blas::UpperLower uplo, uint64_t n, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &ap, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *y, int incy); |
| |
| // See BlasSupport::DoBlasHpr. Element types are complex but alpha is real (float / double); `ap` is a mutable pointer and no leading dimension is passed. |
| Stream &ThenBlasHpr(blas::UpperLower uplo, uint64_t n, float alpha, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| DeviceMemory<std::complex<float>> *ap); |
| Stream &ThenBlasHpr(blas::UpperLower uplo, uint64_t n, double alpha, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| DeviceMemory<std::complex<double>> *ap); |
| |
| // See BlasSupport::DoBlasHpr2. Complex-only overloads with complex alpha; x and y are read-only, `ap` is a mutable pointer and no leading dimension is passed. |
| Stream &ThenBlasHpr2(blas::UpperLower uplo, uint64_t n, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| const DeviceMemory<std::complex<float>> &y, int incy, |
| DeviceMemory<std::complex<float>> *ap); |
| Stream &ThenBlasHpr2(blas::UpperLower uplo, uint64_t n, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| const DeviceMemory<std::complex<double>> &y, int incy, |
| DeviceMemory<std::complex<double>> *ap); |
| |
| // See BlasSupport::DoBlasSbmv. Real-valued overloads only (float, double); a and x are read-only and y is a mutable DeviceMemory pointer. |
| Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64_t n, uint64 k, float alpha, |
| const DeviceMemory<float> &a, int lda, |
| const DeviceMemory<float> &x, int incx, float beta, |
| DeviceMemory<float> *y, int incy); |
| Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64_t n, uint64 k, |
| double alpha, const DeviceMemory<double> &a, int lda, |
| const DeviceMemory<double> &x, int incx, double beta, |
| DeviceMemory<double> *y, int incy); |
| |
| // See BlasSupport::DoBlasSpmv. Real-valued overloads only (float, double); takes `ap` with no leading-dimension argument; y is a mutable DeviceMemory pointer. |
| Stream &ThenBlasSpmv(blas::UpperLower uplo, uint64_t n, float alpha, |
| const DeviceMemory<float> &ap, |
| const DeviceMemory<float> &x, int incx, float beta, |
| DeviceMemory<float> *y, int incy); |
| Stream &ThenBlasSpmv(blas::UpperLower uplo, uint64_t n, double alpha, |
| const DeviceMemory<double> &ap, |
| const DeviceMemory<double> &x, int incx, double beta, |
| DeviceMemory<double> *y, int incy); |
| |
| // See BlasSupport::DoBlasSpr. Real-valued overloads only (float, double); x is read-only, `ap` is a mutable pointer and no leading dimension is passed. |
| Stream &ThenBlasSpr(blas::UpperLower uplo, uint64_t n, float alpha, |
| const DeviceMemory<float> &x, int incx, |
| DeviceMemory<float> *ap); |
| Stream &ThenBlasSpr(blas::UpperLower uplo, uint64_t n, double alpha, |
| const DeviceMemory<double> &x, int incx, |
| DeviceMemory<double> *ap); |
| |
| // See BlasSupport::DoBlasSpr2. Real-valued overloads only (float, double); x and y are read-only, `ap` is a mutable pointer and no leading dimension is passed. |
| Stream &ThenBlasSpr2(blas::UpperLower uplo, uint64_t n, float alpha, |
| const DeviceMemory<float> &x, int incx, |
| const DeviceMemory<float> &y, int incy, |
| DeviceMemory<float> *ap); |
| Stream &ThenBlasSpr2(blas::UpperLower uplo, uint64_t n, double alpha, |
| const DeviceMemory<double> &x, int incx, |
| const DeviceMemory<double> &y, int incy, |
| DeviceMemory<double> *ap); |
| |
| // See BlasSupport::DoBlasSymv. Real-valued overloads only (float, double); a (with `lda`) and x are read-only and y is a mutable DeviceMemory pointer. |
| Stream &ThenBlasSymv(blas::UpperLower uplo, uint64_t n, float alpha, |
| const DeviceMemory<float> &a, int lda, |
| const DeviceMemory<float> &x, int incx, float beta, |
| DeviceMemory<float> *y, int incy); |
| Stream &ThenBlasSymv(blas::UpperLower uplo, uint64_t n, double alpha, |
| const DeviceMemory<double> &a, int lda, |
| const DeviceMemory<double> &x, int incx, double beta, |
| DeviceMemory<double> *y, int incy); |
| |
| // See BlasSupport::DoBlasSyr. Real-valued overloads only (float, double); x is read-only and `a` (with `lda`) is a mutable DeviceMemory pointer. |
| Stream &ThenBlasSyr(blas::UpperLower uplo, uint64_t n, float alpha, |
| const DeviceMemory<float> &x, int incx, |
| DeviceMemory<float> *a, int lda); |
| Stream &ThenBlasSyr(blas::UpperLower uplo, uint64_t n, double alpha, |
| const DeviceMemory<double> &x, int incx, |
| DeviceMemory<double> *a, int lda); |
| |
| // See BlasSupport::DoBlasSyr2. Real-valued overloads only (float, double); x and y are read-only and `a` (with `lda`) is a mutable DeviceMemory pointer. |
| Stream &ThenBlasSyr2(blas::UpperLower uplo, uint64_t n, float alpha, |
| const DeviceMemory<float> &x, int incx, |
| const DeviceMemory<float> &y, int incy, |
| DeviceMemory<float> *a, int lda); |
| Stream &ThenBlasSyr2(blas::UpperLower uplo, uint64_t n, double alpha, |
| const DeviceMemory<double> &x, int incx, |
| const DeviceMemory<double> &y, int incy, |
| DeviceMemory<double> *a, int lda); |
| |
| // See BlasSupport::DoBlasTbmv. Overloaded for float, double and both complex types; `a` is read-only and x is the only mutable argument (no separate output parameter). |
| Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64_t n, uint64 k, |
| const DeviceMemory<float> &a, int lda, |
| DeviceMemory<float> *x, int incx); |
| Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64_t n, uint64 k, |
| const DeviceMemory<double> &a, int lda, |
| DeviceMemory<double> *x, int incx); |
| Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64_t n, uint64 k, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| DeviceMemory<std::complex<float>> *x, int incx); |
| Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64_t n, uint64 k, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| DeviceMemory<std::complex<double>> *x, int incx); |
| |
| // See BlasSupport::DoBlasTbsv. Overloaded for float, double and both complex types; `a` is read-only and x is the only mutable argument (no separate output parameter). |
| Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64_t n, uint64 k, |
| const DeviceMemory<float> &a, int lda, |
| DeviceMemory<float> *x, int incx); |
| Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64_t n, uint64 k, |
| const DeviceMemory<double> &a, int lda, |
| DeviceMemory<double> *x, int incx); |
| Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64_t n, uint64 k, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| DeviceMemory<std::complex<float>> *x, int incx); |
| Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64_t n, uint64 k, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| DeviceMemory<std::complex<double>> *x, int incx); |
| |
| // See BlasSupport::DoBlasTpmv. Overloaded for float, double and both complex types; `ap` is read-only with no leading dimension, and x is the only mutable argument. |
| Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64_t n, |
| const DeviceMemory<float> &ap, DeviceMemory<float> *x, |
| int incx); |
| Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64_t n, |
| const DeviceMemory<double> &ap, DeviceMemory<double> *x, |
| int incx); |
| Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64_t n, |
| const DeviceMemory<std::complex<float>> &ap, |
| DeviceMemory<std::complex<float>> *x, int incx); |
| Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64_t n, |
| const DeviceMemory<std::complex<double>> &ap, |
| DeviceMemory<std::complex<double>> *x, int incx); |
| |
| // See BlasSupport::DoBlasTpsv. Overloaded for float, double and both complex types; `ap` is read-only with no leading dimension, and x is the only mutable argument. |
| Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64_t n, |
| const DeviceMemory<float> &ap, DeviceMemory<float> *x, |
| int incx); |
| Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64_t n, |
| const DeviceMemory<double> &ap, DeviceMemory<double> *x, |
| int incx); |
| Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64_t n, |
| const DeviceMemory<std::complex<float>> &ap, |
| DeviceMemory<std::complex<float>> *x, int incx); |
| Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64_t n, |
| const DeviceMemory<std::complex<double>> &ap, |
| DeviceMemory<std::complex<double>> *x, int incx); |
| |
| // See BlasSupport::DoBlasTrmv. Overloaded for float, double and both complex types; `a` (with `lda`) is read-only and x is the only mutable argument. |
| Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64_t n, |
| const DeviceMemory<float> &a, int lda, |
| DeviceMemory<float> *x, int incx); |
| Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64_t n, |
| const DeviceMemory<double> &a, int lda, |
| DeviceMemory<double> *x, int incx); |
| Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64_t n, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| DeviceMemory<std::complex<float>> *x, int incx); |
| Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64_t n, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| DeviceMemory<std::complex<double>> *x, int incx); |
| |
| // See BlasSupport::DoBlasTrsv. Overloaded for float, double and both complex types; `a` (with `lda`) is read-only and x is the only mutable argument. |
| Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64_t n, |
| const DeviceMemory<float> &a, int lda, |
| DeviceMemory<float> *x, int incx); |
| Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64_t n, |
| const DeviceMemory<double> &a, int lda, |
| DeviceMemory<double> *x, int incx); |
| Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64_t n, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| DeviceMemory<std::complex<float>> *x, int incx); |
| Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64_t n, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| DeviceMemory<std::complex<double>> *x, int incx); |
| |
| // Convenience GEMM overload with alpha fixed to 1 and beta fixed to 0. |
| // Both constants are built as InputType and forwarded, together with |
| // `precision`, to the ThenBlasGemm overload that takes explicit constants. |
| template <typename InputType> |
| port::Status ThenBlasGemm( |
| blas::Transpose transa, blas::Transpose transb, |
| uint64_t m, uint64 n, uint64 k, |
| const DeviceMemory<InputType> &a, int lda, |
| const DeviceMemory<InputType> &b, int ldb, |
| DeviceMemory<InputType> *c, int ldc, |
| blas::ComputePrecision precision) { |
| InputType unit_alpha{1.0}; |
| InputType zero_beta{0.0}; |
| return ThenBlasGemm( |
| transa, transb, m, n, k, |
| unit_alpha, a, lda, b, ldb, |
| zero_beta, c, ldc, precision); |
| } |
| |
| // TODO(parkers): Update all callers to pass kDefaultComputePrecision. |
| // Overload without a precision argument: forwards every argument unchanged |
| // and supplies blas::kDefaultComputePrecision as the precision. |
| template <typename InputType> |
| port::Status ThenBlasGemm( |
| blas::Transpose transa, blas::Transpose transb, |
| uint64_t m, uint64 n, uint64 k, |
| const DeviceMemory<InputType> &a, int lda, |
| const DeviceMemory<InputType> &b, int ldb, |
| DeviceMemory<InputType> *c, int ldc) { |
| return ThenBlasGemm( |
| transa, transb, m, n, k, |
| a, lda, b, ldb, c, ldc, |
| blas::kDefaultComputePrecision); |
| } |
| |
| // Runs a GEMM through the parent executor's BLAS support with caller-supplied |
| // alpha/beta constants and an explicit compute precision. |
| // |
| // Compile-time constraints (see the static_asserts below): |
| // * InputType is one of Eigen::half, Eigen::bfloat16, float, double, |
| // std::complex<float>, std::complex<double>. |
| // * If InputType is Eigen::half, ConstantType is float or Eigen::half. |
| // * Otherwise ConstantType must be the same type as InputType. |
| // |
| // Returns port::InternalError when parent()->AsBlas() is null; otherwise |
| // returns whatever BlasSupport::DoBlasGemm returns. |
| // NOTE(review): sizes mix `uint64_t` and `uint64` spellings - presumably |
| // aliases; confirm before unifying. |
| template <typename InputType, typename ConstantType> |
| port::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, |
| uint64_t m, uint64 n, uint64 k, ConstantType alpha, |
| const DeviceMemory<InputType> &a, int lda, |
| const DeviceMemory<InputType> &b, int ldb, |
| ConstantType beta, DeviceMemory<InputType> *c, |
| int ldc, blas::ComputePrecision precision) { |
| static_assert( |
| detail::is_any_of<InputType, Eigen::half, Eigen::bfloat16, float, |
| double, std::complex<float>, std::complex<double>>(), |
| "Input can be half, bf16, float, double, std::complex<float> or " |
| "std::complex<double>"); |
| static_assert(!std::is_same_v<InputType, Eigen::half> || |
| detail::is_any_of<ConstantType, float, Eigen::half>(), |
| "If input is Eigen::half, constant has to be either " |
| "Eigen::half or float"); |
| static_assert( |
| detail::is_any_of<InputType, Eigen::half, ConstantType>(), |
| "If input is not Eigen::half, constant and input types have to match"); |
| blas::BlasSupport *blas = parent()->AsBlas(); |
| if (!blas) { |
| return port::InternalError( |
| "Attempting to perform BLAS operation using " |
| "StreamExecutor without BLAS support"); |
| } |
| |
| // The constants are passed to the BLAS layer as untyped pointers. They |
| // start out pointing at the by-value parameters. |
| void *alpha_ptr = &alpha; |
| void *beta_ptr = &beta; |
| // Scratch floats handed to UpcastHalfToFloat (defined elsewhere in this |
| // class); presumably it redirects the pointers here when ConstantType is |
| // Eigen::half - confirm against its definition. |
| float alpha_storage, beta_storage; |
| UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage, |
| &beta_storage); |
| |
| return blas->DoBlasGemm(this, transa, transb, m, n, k, |
| blas::ToDataType<InputType>::value, alpha_ptr, a, |
| lda, b, ldb, beta_ptr, c, ldc, precision); |
| } |
| |
| // TODO(parkers): Update all callers to pass kDefaultComputePrecision. |
| // Overload with explicit alpha/beta but no precision argument: forwards |
| // every argument unchanged and supplies blas::kDefaultComputePrecision. |
| template <typename InputType, typename ConstantType> |
| port::Status ThenBlasGemm( |
| blas::Transpose transa, blas::Transpose transb, |
| uint64_t m, uint64 n, uint64 k, |
| ConstantType alpha, |
| const DeviceMemory<InputType> &a, int lda, |
| const DeviceMemory<InputType> &b, int ldb, |
| ConstantType beta, |
| DeviceMemory<InputType> *c, int ldc) { |
| return ThenBlasGemm( |
| transa, transb, m, n, k, |
| alpha, a, lda, b, ldb, |
| beta, c, ldc, |
| blas::kDefaultComputePrecision); |
| } |
| |
| // Profiled GEMM overloads for Eigen::half, float, double, std::complex<float> |
| // and std::complex<double>. The Eigen::half overload takes float alpha/beta; |
| // every other overload uses the element type for both constants. Each takes |
| // a trailing blas::ProfileResult *output_profile_result and returns Stream&. |
| Stream &ThenBlasGemmWithProfiling( |
| blas::Transpose transa, blas::Transpose transb, |
| uint64_t m, uint64 n, uint64_t k, |
| float alpha, |
| const DeviceMemory<Eigen::half> &a, int lda, |
| const DeviceMemory<Eigen::half> &b, int ldb, |
| float beta, |
| DeviceMemory<Eigen::half> *c, int ldc, |
| blas::ProfileResult *output_profile_result); |
| Stream &ThenBlasGemmWithProfiling( |
| blas::Transpose transa, blas::Transpose transb, |
| uint64_t m, uint64 n, uint64_t k, |
| float alpha, |
| const DeviceMemory<float> &a, int lda, |
| const DeviceMemory<float> &b, int ldb, |
| float beta, |
| DeviceMemory<float> *c, int ldc, |
| blas::ProfileResult *output_profile_result); |
| Stream &ThenBlasGemmWithProfiling( |
| blas::Transpose transa, blas::Transpose transb, |
| uint64_t m, uint64 n, uint64_t k, |
| double alpha, |
| const DeviceMemory<double> &a, int lda, |
| const DeviceMemory<double> &b, int ldb, |
| double beta, |
| DeviceMemory<double> *c, int ldc, |
| blas::ProfileResult *output_profile_result); |
| Stream &ThenBlasGemmWithProfiling( |
| blas::Transpose transa, blas::Transpose transb, |
| uint64_t m, uint64 n, uint64_t k, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| const DeviceMemory<std::complex<float>> &b, int ldb, |
| std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *c, int ldc, |
| blas::ProfileResult *output_profile_result); |
| Stream &ThenBlasGemmWithProfiling( |
| blas::Transpose transa, blas::Transpose transb, |
| uint64_t m, uint64 n, uint64_t k, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| const DeviceMemory<std::complex<double>> &b, int ldb, |
| std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *c, int ldc, |
| blas::ProfileResult *output_profile_result); |
| |
| // Convenience overload of ThenBlasGemmWithAlgorithm with alpha fixed to 1 and |
| // beta fixed to 0. Both constants are built as OutputType and forwarded to |
| // the overload that takes explicit constants. |
| template <typename InputType, typename OutputType> |
| port::Status ThenBlasGemmWithAlgorithm( |
| blas::Transpose transa, blas::Transpose transb, |
| uint64_t m, uint64 n, uint64_t k, |
| const DeviceMemory<InputType> &a, int lda, |
| const DeviceMemory<InputType> &b, int ldb, |
| DeviceMemory<OutputType> *c, int ldc, |
| blas::ComputationType computation_type, |
| blas::AlgorithmType algorithm, |
| blas::ProfileResult *output_profile_result) { |
| OutputType unit_alpha{1}; |
| OutputType zero_beta{0}; |
| return ThenBlasGemmWithAlgorithm( |
| transa, transb, m, n, k, |
| unit_alpha, a, lda, b, ldb, |
| zero_beta, c, ldc, |
| computation_type, algorithm, output_profile_result); |
| } |
| |
| // Runs a GEMM with an explicitly chosen algorithm and computation type. |
| // |
| // The type combination is validated first via CheckTypesForExtendedBlas; a |
| // failure there is returned immediately. Returns port::InternalError when |
| // parent()->AsBlas() is null. |
| // |
| // When `output_profile_result` is non-null this always returns OkStatus, |
| // whatever DoBlasGemmWithAlgorithm reported; callers must inspect the |
| // profile result instead. With a null profile pointer the BLAS status is |
| // returned as-is. |
| template <typename InputType, typename OutputType, typename ConstantType> |
| port::Status ThenBlasGemmWithAlgorithm( |
| blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, |
| uint64_t k, ConstantType alpha, const DeviceMemory<InputType> &a, int lda, |
| const DeviceMemory<InputType> &b, int ldb, ConstantType beta, |
| DeviceMemory<OutputType> *c, int ldc, |
| blas::ComputationType computation_type, blas::AlgorithmType algorithm, |
| blas::ProfileResult *output_profile_result) { |
| TF_RETURN_IF_ERROR( |
| CheckTypesForExtendedBlas<InputType, OutputType, ConstantType>( |
| computation_type)); |
| |
| blas::BlasSupport *blas = parent()->AsBlas(); |
| if (!blas) { |
| return port::InternalError( |
| "Attempting to perform BLAS operation using " |
| "StreamExecutor without BLAS support"); |
| } |
| |
| // The constants are passed to the BLAS layer as untyped pointers. They |
| // start out pointing at the by-value parameters. |
| void *alpha_ptr = &alpha; |
| void *beta_ptr = &beta; |
| // Scratch floats handed to UpcastHalfToFloat (defined elsewhere in this |
| // class); presumably it redirects the pointers here when ConstantType is |
| // Eigen::half - confirm against its definition. |
| float alpha_storage, beta_storage; |
| UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage, |
| &beta_storage); |
| |
| port::Status st = blas->DoBlasGemmWithAlgorithm( |
| this, transa, transb, m, n, k, alpha_ptr, a, |
| blas::ToDataType<InputType>::value, lda, b, |
| blas::ToDataType<InputType>::value, ldb, beta_ptr, c, |
| blas::ToDataType<OutputType>::value, ldc, computation_type, algorithm, |
| output_profile_result); |
| if (output_profile_result) { |
| // The error is recorded in the profile. |
| return ::tensorflow::OkStatus(); |
| } |
| return st; |
| } |
| |
| // Strided-batched counterpart of ThenBlasGemmWithAlgorithm: additionally |
| // takes stride_a / stride_b / stride_c and batch_count and forwards them to |
| // BlasSupport::DoBlasGemmStridedBatchedWithAlgorithm. |
| // |
| // The type combination is validated first via CheckTypesForExtendedBlas; a |
| // failure there is returned immediately. Returns port::InternalError when |
| // parent()->AsBlas() is null. |
| // |
| // When `output_profile_result` is non-null this always returns OkStatus, |
| // whatever the BLAS call reported; callers must inspect the profile result |
| // instead. With a null profile pointer the BLAS status is returned as-is. |
| template <typename InputType, typename OutputType, typename ConstantType> |
| port::Status ThenBlasGemmStridedBatchedWithAlgorithm( |
| blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, |
| uint64_t k, ConstantType alpha, const DeviceMemory<InputType> &a, int lda, |
| int64_t stride_a, const DeviceMemory<InputType> &b, int ldb, |
| int64_t stride_b, ConstantType beta, DeviceMemory<OutputType> *c, int ldc, |
| int64_t stride_c, int batch_count, blas::ComputationType computation_type, |
| blas::AlgorithmType algorithm, |
| blas::ProfileResult *output_profile_result) { |
| TF_RETURN_IF_ERROR( |
| CheckTypesForExtendedBlas<InputType, OutputType, ConstantType>( |
| computation_type)); |
| |
| blas::BlasSupport *blas = parent()->AsBlas(); |
| if (!blas) { |
| return port::InternalError( |
| "Attempting to perform BLAS operation using " |
| "StreamExecutor without BLAS support"); |
| } |
| // The constants are passed to the BLAS layer as untyped pointers. They |
| // start out pointing at the by-value parameters. |
| void *alpha_ptr = &alpha; |
| void *beta_ptr = &beta; |
| // Scratch floats handed to UpcastHalfToFloat (defined elsewhere in this |
| // class); presumably it redirects the pointers here when ConstantType is |
| // Eigen::half - confirm against its definition. |
| float alpha_storage, beta_storage; |
| UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage, |
| &beta_storage); |
| port::Status st = blas->DoBlasGemmStridedBatchedWithAlgorithm( |
| this, transa, transb, m, n, k, alpha_ptr, a, |
| blas::ToDataType<InputType>::value, lda, stride_a, b, |
| blas::ToDataType<InputType>::value, ldb, stride_b, beta_ptr, c, |
| blas::ToDataType<OutputType>::value, ldc, stride_c, batch_count, |
| computation_type, algorithm, output_profile_result); |
| if (output_profile_result) { |
| // The error is recorded in the profile. |
| return ::tensorflow::OkStatus(); |
| } |
| return st; |
| } |
| |
| // See BlasSupport::DoBlasGemmBatched. a, b and c are absl::Span views of DeviceMemory pointers; Eigen::half overloads take float alpha/beta. The ...WithScratch variants add a ScratchAllocator *scratch_allocator. |
| Stream &ThenBlasGemmBatched( |
| blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, |
| uint64_t k, float alpha, |
| const absl::Span<DeviceMemory<Eigen::half> *const> a, int lda, |
| const absl::Span<DeviceMemory<Eigen::half> *const> b, int ldb, float beta, |
| const absl::Span<DeviceMemory<Eigen::half> *const> c, int ldc, |
| int batch_count); |
| Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb, |
| uint64_t m, uint64 n, uint64 k, float alpha, |
| const absl::Span<DeviceMemory<float> *const> a, |
| int lda, |
| const absl::Span<DeviceMemory<float> *const> b, |
| int ldb, float beta, |
| const absl::Span<DeviceMemory<float> *const> c, |
| int ldc, int batch_count); |
| Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb, |
| uint64_t m, uint64 n, uint64 k, double alpha, |
| const absl::Span<DeviceMemory<double> *const> a, |
| int lda, |
| const absl::Span<DeviceMemory<double> *const> b, |
| int ldb, double beta, |
| const absl::Span<DeviceMemory<double> *const> c, |
| int ldc, int batch_count); |
| Stream &ThenBlasGemmBatched( |
| blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, |
| uint64_t k, std::complex<float> alpha, |
| const absl::Span<DeviceMemory<std::complex<float>> *const> a, int lda, |
| const absl::Span<DeviceMemory<std::complex<float>> *const> b, int ldb, |
| std::complex<float> beta, |
| const absl::Span<DeviceMemory<std::complex<float>> *const> c, int ldc, |
| int batch_count); |
| Stream &ThenBlasGemmBatched( |
| blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, |
| uint64_t k, std::complex<double> alpha, |
| const absl::Span<DeviceMemory<std::complex<double>> *const> a, int lda, |
| const absl::Span<DeviceMemory<std::complex<double>> *const> b, int ldb, |
| std::complex<double> beta, |
| const absl::Span<DeviceMemory<std::complex<double>> *const> c, int ldc, |
| int batch_count); |
| Stream &ThenBlasGemmBatchedWithScratch( |
| blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, |
| uint64_t k, float alpha, |
| const absl::Span<DeviceMemory<Eigen::half> *const> a, int lda, |
| const absl::Span<DeviceMemory<Eigen::half> *const> b, int ldb, float beta, |
| const absl::Span<DeviceMemory<Eigen::half> *const> c, int ldc, |
| int batch_count, ScratchAllocator *scratch_allocator); |
| Stream &ThenBlasGemmBatchedWithScratch( |
| blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, |
| uint64_t k, float alpha, const absl::Span<DeviceMemory<float> *const> a, |
| int lda, const absl::Span<DeviceMemory<float> *const> b, int ldb, |
| float beta, const absl::Span<DeviceMemory<float> *const> c, int ldc, |
| int batch_count, ScratchAllocator *scratch_allocator); |
| Stream &ThenBlasGemmBatchedWithScratch( |
| blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, |
| uint64_t k, double alpha, const absl::Span<DeviceMemory<double> *const> a, |
| int lda, const absl::Span<DeviceMemory<double> *const> b, int ldb, |
| double beta, const absl::Span<DeviceMemory<double> *const> c, int ldc, |
| int batch_count, ScratchAllocator *scratch_allocator); |
| Stream &ThenBlasGemmBatchedWithScratch( |
| blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, |
| uint64_t k, std::complex<float> alpha, |
| const absl::Span<DeviceMemory<std::complex<float>> *const> a, int lda, |
| const absl::Span<DeviceMemory<std::complex<float>> *const> b, int ldb, |
| std::complex<float> beta, |
| const absl::Span<DeviceMemory<std::complex<float>> *const> c, int ldc, |
| int batch_count, ScratchAllocator *scratch_allocator); |
| Stream &ThenBlasGemmBatchedWithScratch( |
| blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, |
| uint64_t k, std::complex<double> alpha, |
| const absl::Span<DeviceMemory<std::complex<double>> *const> a, int lda, |
| const absl::Span<DeviceMemory<std::complex<double>> *const> b, int ldb, |
| std::complex<double> beta, |
| const absl::Span<DeviceMemory<std::complex<double>> *const> c, int ldc, |
| int batch_count, ScratchAllocator *scratch_allocator); |
| |
| // Forwards a strided-batched GEMM on this stream to |
| // BlasSupport::DoBlasGemmStridedBatched. |
| // |
| // InputType is the element type of a, b and c. ConstantType is the type of |
| // alpha and beta: it must equal InputType, except that Eigen::half and |
| // Eigen::bfloat16 inputs may be combined with float scalars. |
| // |
| // Returns an InternalError if the parent StreamExecutor has no BLAS support; |
| // otherwise returns the status reported by the BLAS implementation. |
| template <typename InputType, typename ConstantType> |
| port::Status ThenBlasGemmStridedBatched( |
| blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, |
| uint64_t k, ConstantType alpha, const DeviceMemory<InputType> &a, int lda, |
| int64_t stride_a, const DeviceMemory<InputType> &b, int ldb, |
| int64_t stride_b, ConstantType beta, DeviceMemory<InputType> *c, int ldc, |
| int64_t stride_c, int batch_count) { |
| static_assert( |
| detail::is_any_of<InputType, float, Eigen::half, Eigen::bfloat16, |
| double, std::complex<float>, std::complex<double>>(), |
| "Unsupported input type"); |
| static_assert( |
| std::is_same_v<ConstantType, InputType> || |
| (detail::is_any_of<InputType, Eigen::half, Eigen::bfloat16>() && |
| std::is_same_v<ConstantType, float>), |
| "Mismatched input and alpha/beta types"); |
| blas::BlasSupport *blas = parent()->AsBlas(); |
| if (!blas) { |
| return port::InternalError( |
| "Attempting to perform BLAS operation using " |
| "StreamExecutor without BLAS support"); |
| } |
| |
| // alpha and beta are handed to the BLAS layer as type-erased pointers. |
| // UpcastHalfToFloat may redirect them to the float storage declared below |
| // (presumably when ConstantType is a 16-bit float type - see its |
| // definition). |
| void *alpha_ptr = &alpha; |
| void *beta_ptr = &beta; |
| float alpha_storage, beta_storage; |
| UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage, |
| &beta_storage); |
| |
| return blas->DoBlasGemmStridedBatched( |
| this, transa, transb, m, n, k, blas::ToDataType<InputType>::value, |
| alpha_ptr, a, lda, stride_a, b, ldb, stride_b, beta_ptr, c, ldc, |
| stride_c, batch_count); |
| } |
| |
| // See BlasSupport::DoBlasHemm. Overloads: complex<float>, complex<double>. |
| Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64_t m, |
| uint64_t n, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| const DeviceMemory<std::complex<float>> &b, int ldb, |
| std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *c, int ldc); |
| Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64_t m, |
| uint64_t n, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| const DeviceMemory<std::complex<double>> &b, int ldb, |
| std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *c, int ldc); |
| |
| // See BlasSupport::DoBlasHerk. alpha/beta are real; a and c are complex. |
| Stream &ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, uint64_t n, |
| uint64_t k, float alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| float beta, DeviceMemory<std::complex<float>> *c, |
| int ldc); |
| Stream &ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, uint64_t n, |
| uint64_t k, double alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| double beta, DeviceMemory<std::complex<double>> *c, |
| int ldc); |
| |
| // See BlasSupport::DoBlasHer2k. alpha is complex; beta is real. |
| Stream &ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, |
| uint64_t n, uint64_t k, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| const DeviceMemory<std::complex<float>> &b, int ldb, |
| float beta, DeviceMemory<std::complex<float>> *c, |
| int ldc); |
| Stream &ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, |
| uint64_t n, uint64_t k, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| const DeviceMemory<std::complex<double>> &b, int ldb, |
| double beta, DeviceMemory<std::complex<double>> *c, |
| int ldc); |
| |
| // See BlasSupport::DoBlasSymm (float, double, complex<float/double>). |
| Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64_t m, |
| uint64_t n, float alpha, const DeviceMemory<float> &a, |
| int lda, const DeviceMemory<float> &b, int ldb, |
| float beta, DeviceMemory<float> *c, int ldc); |
| Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64_t m, |
| uint64_t n, double alpha, const DeviceMemory<double> &a, |
| int lda, const DeviceMemory<double> &b, int ldb, |
| double beta, DeviceMemory<double> *c, int ldc); |
| Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64_t m, |
| uint64_t n, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| const DeviceMemory<std::complex<float>> &b, int ldb, |
| std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *c, int ldc); |
| Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64_t m, |
| uint64_t n, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| const DeviceMemory<std::complex<double>> &b, int ldb, |
| std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *c, int ldc); |
| |
| // See BlasSupport::DoBlasSyrk (float, double, complex<float/double>). |
| Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64_t n, |
| uint64_t k, float alpha, const DeviceMemory<float> &a, |
| int lda, float beta, DeviceMemory<float> *c, int ldc); |
| Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64_t n, |
| uint64_t k, double alpha, const DeviceMemory<double> &a, |
| int lda, double beta, DeviceMemory<double> *c, int ldc); |
| Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64_t n, |
| uint64_t k, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *c, int ldc); |
| Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64_t n, |
| uint64_t k, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *c, int ldc); |
| |
| // See BlasSupport::DoBlasSyr2k (float, double, complex<float/double>). |
| Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, |
| uint64_t n, uint64_t k, float alpha, |
| const DeviceMemory<float> &a, int lda, |
| const DeviceMemory<float> &b, int ldb, float beta, |
| DeviceMemory<float> *c, int ldc); |
| Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, |
| uint64_t n, uint64_t k, double alpha, |
| const DeviceMemory<double> &a, int lda, |
| const DeviceMemory<double> &b, int ldb, double beta, |
| DeviceMemory<double> *c, int ldc); |
| Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, |
| uint64_t n, uint64_t k, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| const DeviceMemory<std::complex<float>> &b, int ldb, |
| std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *c, int ldc); |
| Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, |
| uint64_t n, uint64_t k, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| const DeviceMemory<std::complex<double>> &b, int ldb, |
| std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *c, int ldc); |
| |
| // See BlasSupport::DoBlasTrmm. b is passed by pointer (non-const). |
| Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo, |
| blas::Transpose transa, blas::Diagonal diag, uint64_t m, |
| uint64_t n, float alpha, const DeviceMemory<float> &a, |
| int lda, DeviceMemory<float> *b, int ldb); |
| Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo, |
| blas::Transpose transa, blas::Diagonal diag, uint64_t m, |
| uint64_t n, double alpha, const DeviceMemory<double> &a, |
| int lda, DeviceMemory<double> *b, int ldb); |
| Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo, |
| blas::Transpose transa, blas::Diagonal diag, uint64_t m, |
| uint64_t n, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| DeviceMemory<std::complex<float>> *b, int ldb); |
| Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo, |
| blas::Transpose transa, blas::Diagonal diag, uint64_t m, |
| uint64_t n, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| DeviceMemory<std::complex<double>> *b, int ldb); |
| |
| // See BlasSupport::DoBlasTrsm. b is passed by pointer (non-const). |
| Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, |
| blas::Transpose transa, blas::Diagonal diag, uint64_t m, |
| uint64_t n, float alpha, const DeviceMemory<float> &a, |
| int lda, DeviceMemory<float> *b, int ldb); |
| Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, |
| blas::Transpose transa, blas::Diagonal diag, uint64_t m, |
| uint64_t n, double alpha, const DeviceMemory<double> &a, |
| int lda, DeviceMemory<double> *b, int ldb); |
| Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, |
| blas::Transpose transa, blas::Diagonal diag, uint64_t m, |
| uint64_t n, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| DeviceMemory<std::complex<float>> *b, int ldb); |
| Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, |
| blas::Transpose transa, blas::Diagonal diag, uint64_t m, |
| uint64_t n, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| DeviceMemory<std::complex<double>> *b, int ldb); |
| |
| // See BlasSupport::DoBlasTrsmBatched. as/bs are DeviceMemory of pointers. |
| Stream &ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo, |
| blas::Transpose transa, blas::Diagonal diag, |
| uint64_t m, uint64 n, float alpha, |
| const DeviceMemory<float *> &as, int lda, |
| DeviceMemory<float *> *bs, int ldb, |
| int batch_count); |
| Stream &ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo, |
| blas::Transpose transa, blas::Diagonal diag, |
| uint64_t m, uint64 n, double alpha, |
| const DeviceMemory<double *> &as, int lda, |
| DeviceMemory<double *> *bs, int ldb, |
| int batch_count); |
| Stream &ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo, |
| blas::Transpose transa, blas::Diagonal diag, |
| uint64_t m, uint64 n, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float> *> &as, |
| int lda, DeviceMemory<std::complex<float> *> *bs, |
| int ldb, int batch_count); |
| Stream &ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo, |
| blas::Transpose transa, blas::Diagonal diag, |
| uint64_t m, uint64 n, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double> *> &as, |
| int lda, DeviceMemory<std::complex<double> *> *bs, |
| int ldb, int batch_count); |
| |
| // See FftSupport::DoFft. Overloads cover C2C, R2C and C2R in float/double. |
| Stream &ThenFft(fft::Plan *plan, |
| const DeviceMemory<std::complex<float>> &input, |
| DeviceMemory<std::complex<float>> *output); |
| Stream &ThenFft(fft::Plan *plan, |
| const DeviceMemory<std::complex<double>> &input, |
| DeviceMemory<std::complex<double>> *output); |
| Stream &ThenFft(fft::Plan *plan, const DeviceMemory<float> &input, |
| DeviceMemory<std::complex<float>> *output); |
| Stream &ThenFft(fft::Plan *plan, const DeviceMemory<double> &input, |
| DeviceMemory<std::complex<double>> *output); |
| Stream &ThenFft(fft::Plan *plan, |
| const DeviceMemory<std::complex<float>> &input, |
| DeviceMemory<float> *output); |
| Stream &ThenFft(fft::Plan *plan, |
| const DeviceMemory<std::complex<double>> &input, |
| DeviceMemory<double> *output); |
| |
| // Makes the RNG use the provided value as the basis for further generation. |
| // /dev/urandom (good) and /dev/random (better, but sometimes slow) are good |
| // sources of seed data if the default (high quality) sources are not |
| // desired. |
| // For most use cases, this function will not be necessary; each provided |
| // back-end implementation will be appropriately seeded by default. |
| // The seed buffer must contain at least 16 bytes of data. |
| // |
| // To seed with good (non-reproducible) data: |
| // File* f = File::Open("/dev/random", "r"); |
| // int64_t bytes_read = f->Read(seed_data, bytes_to_read); |
| // < error checking > |
| // stream.ThenSetRngSeed(seed_data, bytes_read); |
| // |
| // To seed with reproducible data (seed_bytes is the buffer length in bytes): |
| // uint64_t seed_data[2] = { <data> }; |
| // stream.ThenSetRngSeed(seed_data, 16); |
| Stream &ThenSetRngSeed(const uint8 *seed, uint64_t seed_bytes); |
| |
| // Populates the memory indicated by values with random values drawn from a |
| // uniform distribution (ThenPopulateRandUniform) or from a Gaussian |
| // distribution with the given mean and stddev (ThenPopulateRandGaussian). |
| // TODO(leary) seeding API/description. Uses the type and size of the |
| // DeviceMemory to infer what data should be populated. |
| Stream &ThenPopulateRandUniform(DeviceMemory<float> *values); |
| Stream &ThenPopulateRandUniform(DeviceMemory<double> *values); |
| Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<float>> *values); |
| Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<double>> *values); |
| Stream &ThenPopulateRandGaussian(float mean, float stddev, |
| DeviceMemory<float> *values); |
| Stream &ThenPopulateRandGaussian(double mean, double stddev, |
| DeviceMemory<double> *values); |
| |
| // Entrain onto the stream: a memcpy of size bytes to a host destination from |
| // a GPU source. host_dst must be a pointer to host memory allocated by |
| // StreamExecutor::HostMemoryAllocate or otherwise allocated and then |
| // registered with StreamExecutor::HostMemoryRegister. |
| Stream &ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src, |
| uint64_t size); |
| |
| // Entrain onto the stream: a memcpy of size bytes to a GPU destination from |
| // a host source. host_src must be a pointer to host memory allocated by |
| // StreamExecutor::HostMemoryAllocate or otherwise allocated and then |
| // registered with StreamExecutor::HostMemoryRegister. |
| Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src, |
| uint64_t size); |
| |
| // Span-based convenience wrapper for a device-to-host memcpy. |
| // |
| // CHECK-fails unless gpu_src.size() is zero or the byte size of host_dst is |
| // at least gpu_src.size(). The number of bytes requested from ThenMemcpy is |
| // the byte size of host_dst. |
| template <typename T> |
| Stream &ThenMemcpyD2H(const DeviceMemory<T> &gpu_src, |
| absl::Span<T> host_dst) { |
| const auto host_size = sizeof(T) * host_dst.size(); |
| CHECK(gpu_src.size() == 0 || host_size >= gpu_src.size()); |
| const auto host_begin = host_dst.begin(); |
| return ThenMemcpy(host_begin, gpu_src, host_size); |
| } |
| |
| // Span-based convenience wrapper for a host-to-device memcpy. |
| // |
| // CHECK-fails unless gpu_dst->size() is zero or gpu_dst->size() is at least |
| // the byte size of host_src. The number of bytes requested from ThenMemcpy |
| // is the byte size of host_src. |
| template <typename T> |
| Stream &ThenMemcpyH2D(absl::Span<const T> host_src, |
| DeviceMemory<T> *gpu_dst) { |
| const auto host_size = sizeof(T) * host_src.size(); |
| CHECK(gpu_dst->size() == 0 || gpu_dst->size() >= host_size); |
| const auto host_begin = host_src.begin(); |
| return ThenMemcpy(gpu_dst, host_begin, host_size); |
| } |
| |
| // Entrain onto the stream: a memcpy to a GPU destination from a GPU source |
| // of the given target size. gpu_src and gpu_dst must refer to GPU memory, |
| // and peer access must be enabled between their owning StreamExecutors. |
| Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src, |
| uint64_t size); |
| |
| // Explicitly named alias for the device-to-device overload of ThenMemcpy. |
| // Handy when you are not metaprogramming against the API and want to be |
| // sure a host pointer is not accidentally mistaken for a device pointer. |
| Stream &ThenMemcpyD2D(DeviceMemoryBase *gpu_dst, |
| const DeviceMemoryBase &gpu_src, uint64_t size) { |
| // Both arguments are device-memory handles, so this selects the |
| // device-to-device overload. |
| return this->ThenMemcpy(gpu_dst, gpu_src, size); |
| } |
| |
| // Entrain onto the stream: a memset to zero of size bytes at a GPU location. |
| // The location must not be null. |
| Stream &ThenMemZero(DeviceMemoryBase *location, uint64_t size); |
| |
| // Entrain onto the stream: a memset of a 32-bit pattern at a GPU location of |
| // size bytes, where size must be a whole number of 32-bit words (i.e. evenly |
| // divisible by 4). The location must not be null. |
| Stream &ThenMemset32(DeviceMemoryBase *location, uint32 pattern, |
| uint64_t size); |
| |
| // Enqueue a forward operation of the RNN model onto the stream. |
| // See DnnSupport::DoRnnForward for more details. |
| // Overloads are declared for Eigen::half, float and double tensor data; the |
| // sequence lengths are always DeviceMemory<int>. |
| Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc, |
| const dnn::RnnSequenceTensorDescriptor &input_desc, |
| const DeviceMemory<Eigen::half> &input_data, |
| const DeviceMemory<int> &seq_lengths_data, |
| const dnn::RnnStateTensorDescriptor &input_h_desc, |
| const DeviceMemory<Eigen::half> &input_h_data, |
| const dnn::RnnStateTensorDescriptor &input_c_desc, |
| const DeviceMemory<Eigen::half> &input_c_data, |
| const DeviceMemory<Eigen::half> &params, |
| const dnn::RnnSequenceTensorDescriptor &output_desc, |
| DeviceMemory<Eigen::half> *output_data, |
| const dnn::RnnStateTensorDescriptor &output_h_desc, |
| DeviceMemory<Eigen::half> *output_h_data, |
| const dnn::RnnStateTensorDescriptor &output_c_desc, |
| DeviceMemory<Eigen::half> *output_c_data, |
| bool is_training, |
| ScratchAllocator *reserve_space_allocator, |
| ScratchAllocator *workspace_allocator, |
| dnn::ProfileResult *output_profile_result); |
| |
| Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc, |
| const dnn::RnnSequenceTensorDescriptor &input_desc, |
| const DeviceMemory<float> &input_data, |
| const DeviceMemory<int> &seq_lengths_data, |
| const dnn::RnnStateTensorDescriptor &input_h_desc, |
| const DeviceMemory<float> &input_h_data, |
| const dnn::RnnStateTensorDescriptor &input_c_desc, |
| const DeviceMemory<float> &input_c_data, |
| const DeviceMemory<float> &params, |
| const dnn::RnnSequenceTensorDescriptor &output_desc, |
| DeviceMemory<float> *output_data, |
| const dnn::RnnStateTensorDescriptor &output_h_desc, |
| DeviceMemory<float> *output_h_data, |
| const dnn::RnnStateTensorDescriptor &output_c_desc, |
| DeviceMemory<float> *output_c_data, bool is_training, |
| ScratchAllocator *reserve_space_allocator, |
| ScratchAllocator *workspace_allocator, |
| dnn::ProfileResult *output_profile_result); |
| |
| Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc, |
| const dnn::RnnSequenceTensorDescriptor &input_desc, |
| const DeviceMemory<double> &input_data, |
| const DeviceMemory<int> &seq_lengths_data, |
| const dnn::RnnStateTensorDescriptor &input_h_desc, |
| const DeviceMemory<double> &input_h_data, |
| const dnn::RnnStateTensorDescriptor &input_c_desc, |
| const DeviceMemory<double> &input_c_data, |
| const DeviceMemory<double> &params, |
| const dnn::RnnSequenceTensorDescriptor &output_desc, |
| DeviceMemory<double> *output_data, |
| const dnn::RnnStateTensorDescriptor &output_h_desc, |
| DeviceMemory<double> *output_h_data, |
| const dnn::RnnStateTensorDescriptor &output_c_desc, |
| DeviceMemory<double> *output_c_data, bool is_training, |
| ScratchAllocator *reserve_space_allocator, |
| ScratchAllocator *workspace_allocator, |
| dnn::ProfileResult *output_profile_result); |
| |
| // Enqueue a backward operation of the RNN model onto the stream. |
| // See DnnSupport::DoRnnBackward for more details. |
| // Overloads are declared for Eigen::half, float and double tensor data; the |
| // sequence lengths are always DeviceMemory<int> and the reserve space is |
| // always DeviceMemory<uint8>. |
| Stream &ThenRnnBackward( |
| const dnn::RnnDescriptor &rnn_desc, |
| const dnn::RnnSequenceTensorDescriptor &input_desc, |
| const DeviceMemory<Eigen::half> &input_data, |
| const DeviceMemory<int> &seq_lengths_data, |
| const dnn::RnnStateTensorDescriptor &input_h_desc, |
| const DeviceMemory<Eigen::half> &input_h_data, |
| const dnn::RnnStateTensorDescriptor &input_c_desc, |
| const DeviceMemory<Eigen::half> &input_c_data, |
| const DeviceMemory<Eigen::half> &params, |
| const dnn::RnnSequenceTensorDescriptor &output_desc, |
| const DeviceMemory<Eigen::half> &output_data, |
| const dnn::RnnStateTensorDescriptor &output_h_desc, |
| const DeviceMemory<Eigen::half> &output_h_data, |
| const dnn::RnnStateTensorDescriptor &output_c_desc, |
| const DeviceMemory<Eigen::half> &output_c_data, |
| const DeviceMemory<Eigen::half> &output_backprop_data, |
| const DeviceMemory<Eigen::half> &output_h_backprop_data, |
| const DeviceMemory<Eigen::half> &output_c_backprop_data, |
| DeviceMemory<Eigen::half> *input_backprop_data, |
| DeviceMemory<Eigen::half> *input_h_backprop_data, |
| DeviceMemory<Eigen::half> *input_c_backprop_data, |
| DeviceMemory<Eigen::half> *params_backprop_data, |
| DeviceMemory<uint8> *reserve_space_data, |
| ScratchAllocator *workspace_allocator, |
| dnn::ProfileResult *output_profile_result); |
| |
| Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc, |
| const dnn::RnnSequenceTensorDescriptor &input_desc, |
| const DeviceMemory<float> &input_data, |
| const DeviceMemory<int> &seq_lengths_data, |
| const dnn::RnnStateTensorDescriptor &input_h_desc, |
| const DeviceMemory<float> &input_h_data, |
| const dnn::RnnStateTensorDescriptor &input_c_desc, |
| const DeviceMemory<float> &input_c_data, |
| const DeviceMemory<float> &params, |
| const dnn::RnnSequenceTensorDescriptor &output_desc, |
| const DeviceMemory<float> &output_data, |
| const dnn::RnnStateTensorDescriptor &output_h_desc, |
| const DeviceMemory<float> &output_h_data, |
| const dnn::RnnStateTensorDescriptor &output_c_desc, |
| const DeviceMemory<float> &output_c_data, |
| const DeviceMemory<float> &output_backprop_data, |
| const DeviceMemory<float> &output_h_backprop_data, |
| const DeviceMemory<float> &output_c_backprop_data, |
| DeviceMemory<float> *input_backprop_data, |
| DeviceMemory<float> *input_h_backprop_data, |
| DeviceMemory<float> *input_c_backprop_data, |
| DeviceMemory<float> *params_backprop_data, |
| DeviceMemory<uint8> *reserve_space_data, |
| ScratchAllocator *workspace_allocator, |
| dnn::ProfileResult *output_profile_result); |
| |
| Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc, |
| const dnn::RnnSequenceTensorDescriptor &input_desc, |
| const DeviceMemory<double> &input_data, |
| const DeviceMemory<int> &seq_lengths_data, |
| const dnn::RnnStateTensorDescriptor &input_h_desc, |
| const DeviceMemory<double> &input_h_data, |
| const dnn::RnnStateTensorDescriptor &input_c_desc, |
| const DeviceMemory<double> &input_c_data, |
| const DeviceMemory<double> &params, |
| const dnn::RnnSequenceTensorDescriptor &output_desc, |
| const DeviceMemory<double> &output_data, |
| const dnn::RnnStateTensorDescriptor &output_h_desc, |
| const DeviceMemory<double> &output_h_data, |
| const dnn::RnnStateTensorDescriptor &output_c_desc, |
| const DeviceMemory<double> &output_c_data, |
| const DeviceMemory<double> &output_backprop_data, |
| const DeviceMemory<double> &output_h_backprop_data, |
| const DeviceMemory<double> &output_c_backprop_data, |
| DeviceMemory<double> *input_backprop_data, |
| DeviceMemory<double> *input_h_backprop_data, |
| DeviceMemory<double> *input_c_backprop_data, |
| DeviceMemory<double> *params_backprop_data, |
| DeviceMemory<uint8> *reserve_space_data, |
| ScratchAllocator *workspace_allocator, |
| dnn::ProfileResult *output_profile_result); |
| |
| // Enqueue a CTCLoss operation onto the stream (float probs/costs/grads). |
| // See DnnSupport::DoCtcLoss for more details. |
| Stream &ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc, |
| const DeviceMemory<float> &probs_data, |
| absl::Span<const int> labels_data, |
| absl::Span<const int> labels_lengths_data, |
| absl::Span<const int> input_lengths_data, |
| DeviceMemory<float> *costs_data, |
| const dnn::RnnStateTensorDescriptor &grads_desc, |
| DeviceMemory<float> *grads_data, |
| ScratchAllocator *workspace_allocator); |
| |
| // Enqueue onto the stream an operation that transforms a tensor. |
| // See DnnSupport::DoTransformTensor for more details. |
| Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc, |
| dnn::DataType input_type, |
| const DeviceMemoryBase &input_data, |
| const dnn::BatchDescriptor &output_desc, |
| dnn::DataType output_type, float scale, |
| DeviceMemoryBase *output_data); |
| |
| // The templated version of the above ThenTransformTensor. Useful when the |
| // input and output types are statically known: the dnn::DataType arguments |
| // of the untyped overload are derived from InElemT and OutElemT. |
| // |
| // The untyped overload requires a scale argument; this wrapper passes 1.0f. |
| // NOTE(review): 1.0f is assumed to be the intended "no scaling" value, and |
| // dnn::ToDataType<T>::value is assumed to name the dnn::DataType for T (as |
| // blas::ToDataType<T>::value does for BLAS types) - confirm against dnn.h. |
| template <typename InElemT, typename OutElemT> |
| Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc, |
| const DeviceMemory<InElemT> &input_data, |
| const dnn::BatchDescriptor &output_desc, |
| DeviceMemory<OutElemT> *output_data) { |
| return ThenTransformTensor(input_desc, dnn::ToDataType<InElemT>::value, |
| input_data, output_desc, |
| dnn::ToDataType<OutElemT>::value, /*scale=*/1.0f, |
| output_data); |
| } |
| |
| // Synchronously blocks the host code until the operations entrained on the |
| // stream (enqueued up to this point in program execution) have completed. |
| // Must not be called while holding mu_. |
| // |
| // Returns an OK status if the blocking was successful and the stream is ok(). |
| // Otherwise returns an error describing why the blocking failed. |
| port::Status BlockHostUntilDone() TF_LOCKS_EXCLUDED(mu_); |
| |
| // Warning! This method interacts with internal threads in |
| // sometimes-unpredictable ways and is intended for GPU-Executor-internal use |
| // only. Please check with a member of the FASTR team before making use of |
| // this method. |
| // |
| // Entrains onto the stream a function to be executed on the host at some |
| // point in the future. |
| // Async host callbacks DO NOT block the stream as device functions (or as |
| // synchronous host callbacks). No synchronization is possible with |
| // asynchronous callbacks; they are strictly fire-and-forget. |
| // This method is private due to the potential for undefined behavior with |
| // synchronization using OpenCL user events. |
| // NOTE(review): it is declared before `private:` here; wording may be stale. |
| // The ONLY lifetime guarantee in these calls is that the StreamExecutor |
| // parameter will still be valid - this Stream may not be! |
| // Any callbacks requiring device API calls must use this method. |
| Stream &ThenEnqueueOnBackgroundThread( |
| std::function<void(StreamExecutor *)> task); |
| |
| // Accessor for the opaque, platform-specific object backing this stream. |
| // The returned pointer is borrowed: the Stream keeps ownership. |
| internal::StreamInterface *implementation() { |
| internal::StreamInterface *const backing = implementation_.get(); |
| return backing; |
| } |
| |
| // Entrains onto the stream a callback to the host (from the device). |
| // Behaves as ThenDoHostCallbackWithStatus below, but the callback returns |
| // void: it should never fail, or its failure is inconsequential. |
| // |
| // This is kept for backward compatibility. New code should use |
| // ThenDoHostCallbackWithStatus and explicitly return a success status. |
| // TODO(b/112125301): Eventually remove this method. |
| Stream &ThenDoHostCallback(std::function<void()> callback); |
| |
| // Entrains onto the stream a callback to the host (from the device). |
| // Host callbacks block/occupy the stream just as device functions do |
| // (they execute one at a time and block later stream operations). |
| // Whether the callback's returned status affects the result of |
| // BlockHostUntilDone is platform-dependent. |
| // |
| // Behavior is undefined when synchronizing using OpenCL user events. |
| // Behavior is undefined if host callbacks call device routines or insert |
| // them into any stream. |
| // |
| // On certain platforms, ThenDoHostCallback is expected to have significant |
| // negative effects on performance. |
| Stream &ThenDoHostCallbackWithStatus(std::function<port::Status()> callback); |
| |
| // Runs the given callback after the next call to BlockHostUntilDone on this |
| // stream (or after the Stream does BlockHostUntilDone in its destructor). |
| // For some use cases this can serve as a faster alternative to |
| // ThenDoHostCallbackWithStatus. |
| Stream &ThenRunAfterNextBlockHostUntilDone(std::function<void()> callback); |
| |
| // Accessor for the StreamExecutor this stream belongs to (its parent). |
| // CHECK-fails if the parent pointer is null. |
| StreamExecutor *parent() const { |
| StreamExecutor *const executor = parent_; |
| CHECK(parent_ != nullptr); |
| return executor; |
| } |
| |
| // Returns the CUDA compute capability recorded in the parent |
| // StreamExecutor's device description. |
| CudaComputeCapability GetCudaComputeCapability() const { |
| const auto &description = parent()->GetDeviceDescription(); |
| return description.cuda_compute_capability(); |
| } |
| |
| // Returns the ROCm compute capability recorded in the parent |
| // StreamExecutor's device description. |
| RocmComputeCapability GetRocmComputeCapability() const { |
| const auto &description = parent()->GetDeviceDescription(); |
| return description.rocm_compute_capability(); |
| } |
| // Returns the temporary-memory-allocation manager associated with this |
| // stream. Intended for internal usage. |
| internal::TemporaryMemoryManager *temporary_memory_manager(); |
| |
| // Returns a debugging string of the form "[stream=0x...,impl=0x...]". |
| std::string DebugStreamPointers() const; |
| |
| private: |
| friend class host::HostBlas; // needs access to parent_. |
| friend class host::HostFft; // needs access to parent_. |
| friend class host::HostRng; // needs access to parent_. |
| template <typename... Args> |
| friend struct ThenBlasImpl; // helper used to implement ThenBlasXXX. |
| friend class ocl::CLBlas; // needs access to parent_. |
| |
| // Checks whether types match before a call to extended BLAS version. |
| // |
| // Compile-time checks (static_assert): |
| //  - ABType (the a/b buffer element type) is one of the supported types; |
| //  - CType (the output element type) equals ABType, unless ABType is int8_t |
| //    and CType is int32_t; |
| //  - ScaleType (the alpha/beta type) equals CType, unless CType is |
| //    Eigen::half or Eigen::bfloat16 and ScaleType is float. |
| // |
| // Run-time check: computation_type must be compatible with CType. Returns |
| // OkStatus() when it is and an InternalError naming both otherwise. |
| template <typename ABType, typename CType, typename ScaleType> |
| port::Status CheckTypesForExtendedBlas( |
| blas::ComputationType computation_type) { |
| static_assert( |
| detail::is_any_of<ABType, Eigen::half, Eigen::bfloat16, float, double, |
| int8_t, std::complex<float>, std::complex<double>>(), |
| "The only buffer types supported are: Eigen::half, Eigen::bfloat16, " |
| "float, double, int8, std::complex<float> and std::complex<double>"); |
| static_assert( |
| std::is_same_v<ABType, CType> || |
| (std::is_same_v<ABType, int8_t> && std::is_same_v<CType, int32_t>), |
| "Input and output buffer types should be the same unless input is " |
| "int8 and output is int32"); |
| static_assert( |
| std::is_same_v<ScaleType, CType> || |
| (std::is_same_v<ScaleType, float> && |
| detail::is_any_of<CType, Eigen::half, Eigen::bfloat16>()), |
| "Mismatched alpha/beta and output types"); |
| |
| bool valid_computation_type = [computation_type] { |
| switch (computation_type) { |
| case blas::ComputationType::kF16: |
| return std::is_same_v<CType, Eigen::half>; |
| case blas::ComputationType::kF32: |
| return detail::is_any_of<CType, Eigen::half, Eigen::bfloat16, float, |
| std::complex<float>>(); |
| case blas::ComputationType::kF64: |
| return detail::is_any_of<CType, double, std::complex<double>>(); |
| case blas::ComputationType::kI32: |
| return std::is_same_v<CType, int32_t>; |
| case blas::ComputationType::kF16AsF32: // fall-through |
| case blas::ComputationType::kBF16AsF32: // fall-through |
| case blas::ComputationType::kTF32AsF32: |
| return detail::is_any_of<CType, float, std::complex<float>>(); |
| } |
| // Reached only for a value outside the enumerators handled above; treat |
| // it as invalid instead of flowing off the end of the lambda. |
| return false; |
| }(); |
| |
| if (!valid_computation_type) { |
| return port::InternalError(absl::StrCat( |
| "Invalid computation type ", |
| blas::ComputationTypeString(computation_type), " for output type: ", |
| blas::DataTypeString(blas::ToDataType<CType>::value))); |
| } |
| return ::tensorflow::OkStatus(); |
| } |
| |
| bool InErrorState() const TF_LOCKS_EXCLUDED(mu_) { |
| absl::ReaderMutexLock lock(&mu_); |
| return !status_.ok(); |
| } |
| |
| // Sets the error state if operation_retcode is false. |
| // This is a useful shorthand for many stream routines. |
| void CheckError(bool operation_retcode) TF_LOCKS_EXCLUDED(mu_); |
| |
| // Checks the status and logs the error message, if any. |
| void CheckStatus(port::Status status) TF_LOCKS_EXCLUDED(mu_); |
| |
| void SetError() { CheckError(false /* = operation_retcode */); } |
| |
| void SetErrorAndLogNoDnnSupport() { |
| SetError(); |
| LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor " |
| "without DNN support"; |
| } |
| |
| // Runs the set of callbacks that are intended to run after |
| // BlockHostUntilDone. |
| void RunAfterBlockHostUntilDoneCallbacks(); |
| |
| // The StreamExecutor that supports the operation of this stream. |
| StreamExecutor *parent_; |
| |
| // The platform-dependent implementation that the StreamExecutor interface |
| // delegates to. |
| std::unique_ptr<internal::StreamInterface> implementation_; |
| |
// Mutex that guards the allocation / error state flags.
// Mutable so that it can be acquired via a reader lock in const methods.
| mutable absl::Mutex mu_; |
| |
// Whether Init() was successfully called to allocate this stream on the
// underlying platform. It simply flips from false to true, with a sanity
// check. See StreamExecutor::AllocateStream.
| bool allocated_ ABSL_GUARDED_BY(mu_); |
| |
| // The last error (if any) of all method calls. |
| port::Status status_ ABSL_GUARDED_BY(mu_); |
| |
// Sub-streams that are generated from this stream. Each element holds a
// pointer to a sub-stream and a boolean value indicating whether that
// sub-stream is ready to be reused.
| std::vector<std::pair<std::unique_ptr<Stream>, bool>> sub_streams_ |
| ABSL_GUARDED_BY(mu_); |
| |
| // Streams can allocate temporary memories to help with work they enqueue |
| // (e.g. for scratch memory spaces). This member tracks those allocations and |
| // notes when they can be reclaimed -- reclamation is attempted when |
| // BlockHostUntilDone() is called. |
| internal::TemporaryMemoryManager temporary_memory_manager_; |
| |
| // Callbacks enqueued to be run after the next call to BlockHostUntilDone(). |
| std::vector<std::function<void()>> after_block_host_until_done_callbacks_ |
| ABSL_GUARDED_BY(mu_); |
| |
| // Non-extended BLAS interface requires alpha/beta to be floats when input |
| // type is Eigen::half. However, for consistency purposes it is convenient |
| // for the interface to accept Eigen::half. |
| template <typename T> |
| void UpcastHalfToFloat(void **alpha_ptr, void **beta_ptr, |
| float *alpha_storage, float *beta_storage) { |
| if (std::is_same<T, Eigen::half>::value) { |
| *alpha_storage = |
| static_cast<float>(*reinterpret_cast<Eigen::half *>(*alpha_ptr)); |
| *beta_storage = |
| static_cast<float>(*reinterpret_cast<Eigen::half *>(*beta_ptr)); |
| *alpha_ptr = alpha_storage; |
| *beta_ptr = beta_storage; |
| } else if (std::is_same<T, Eigen::bfloat16>::value) { |
| *alpha_storage = |
| static_cast<float>(*reinterpret_cast<Eigen::bfloat16 *>(*alpha_ptr)); |
| *beta_storage = |
| static_cast<float>(*reinterpret_cast<Eigen::bfloat16 *>(*beta_ptr)); |
| *alpha_ptr = alpha_storage; |
| *beta_ptr = beta_storage; |
| } |
| } |
| |
| SE_DISALLOW_COPY_AND_ASSIGN(Stream); |
| }; |
| |
| //////////// |
| // Inlines |
| |
| template <typename... Params, typename... Args> |
| inline port::Status Stream::ThenLaunch(ThreadDim thread_dims, |
| BlockDim block_dims, |
| const TypedKernel<Params...> &kernel, |
| Args... args) { |
| KernelInvocationChecker<std::tuple<Params...>, |
| std::tuple<Args...>>::CheckAllStaticAssert(); |
| |
| // This is the core that allows type-safe kernel launching. |
| // Since the platforms take kernel arguments as tuples of (void *, size), |
| // we pack the variadic parameters passed as ...args into the desired |
| // tuple form and pass that packed form to the StreamExecutor::Launch() |
| // implementation. |
| KernelArgsArray<sizeof...(args)> kernel_args; |
| kernel.PackParams(&kernel_args, args...); |
| TF_RETURN_IF_ERROR( |
| parent_->Launch(this, thread_dims, block_dims, kernel, kernel_args)); |
| return ::tensorflow::OkStatus(); |
| } |
| |
| template <typename T> |
| inline port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>> |
| Stream::AllocateTemporaryArray(uint64_t element_count) { |
| return temporary_memory_manager_.AllocateArray<T>(element_count); |
| } |
| |
| inline internal::TemporaryMemoryManager *Stream::temporary_memory_manager() { |
| return &temporary_memory_manager_; |
| } |
| |
| template <> |
| struct Quantization<uint8> { |
| static constexpr dnn::QuantizedActivationMode kModeId = |
| dnn::QuantizedActivationMode::k8Bit; |
| }; |
| |
| template <> |
| struct Quantization<uint16> { |
| static constexpr dnn::QuantizedActivationMode kModeId = |
| dnn::QuantizedActivationMode::k16Bit; |
| }; |
| |
| template <> |
| struct Quantization<int32> { |
| static constexpr dnn::QuantizedActivationMode kModeId = |
| dnn::QuantizedActivationMode::k32Bit; |
| }; |
| |
| } // namespace stream_executor |
| |
| #endif // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_STREAM_H_ |