| /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/stream_executor/stream.h" |
| |
| #include "absl/strings/str_cat.h" |
| #include "third_party/eigen3/Eigen/Core" |
| #include "tensorflow/stream_executor/blas.h" |
| #include "tensorflow/stream_executor/host_or_device_scalar.h" |
| #include "tensorflow/stream_executor/lib/stacktrace.h" |
| #include "tensorflow/stream_executor/platform.h" |
| #include "tensorflow/stream_executor/platform/logging.h" |
| #include "tensorflow/stream_executor/platform/port.h" |
| #include "tensorflow/stream_executor/rng.h" |
| #include "tensorflow/stream_executor/stream_executor_internal.h" |
| #include "tensorflow/stream_executor/stream_executor_pimpl.h" |
| |
| namespace stream_executor { |
| |
| namespace { |
| // Code to turn parameters to functions on stream into strings that |
| // will be VLOG'ed. We need overloads, instead of |
| // e.g. BatchDescriptorToVlogString(), as the code that calls these |
| // functions does not know what the type of the parameter is. |
| std::string ToVlogString(const dnn::BatchDescriptor &descriptor) { |
| return descriptor.ToShortString(); |
| } |
| |
| std::string ToVlogString(const dnn::FilterDescriptor &descriptor) { |
| return descriptor.ToShortString(); |
| } |
| |
| std::string ToVlogString(const dnn::ConvolutionDescriptor &descriptor) { |
| return descriptor.ToShortString(); |
| } |
| |
| std::string ToVlogString(const dnn::PoolingDescriptor &descriptor) { |
| return descriptor.ToShortString(); |
| } |
| |
| std::string ToVlogString(const dnn::NormalizeDescriptor &descriptor) { |
| return descriptor.ToShortString(); |
| } |
| |
| std::string ToVlogString(dnn::ActivationMode mode) { |
| return dnn::ActivationModeString(mode); |
| } |
| |
| std::string ToVlogString(const dnn::AlgorithmConfig &algo_config) { |
| return algo_config.ToString(); |
| } |
| |
| std::string ToVlogString(dnn::ElementwiseOperation op) { |
| return dnn::ElementwiseOperationString(op); |
| } |
| |
| std::string ToVlogString(dnn::QuantizedActivationMode mode) { |
| return dnn::QuantizedActivationModeString(mode); |
| } |
| |
| std::string ToVlogString(blas::Transpose t) { return blas::TransposeString(t); } |
| |
| std::string ToVlogString(blas::UpperLower ul) { |
| return blas::UpperLowerString(ul); |
| } |
| |
| std::string ToVlogString(blas::Diagonal d) { return blas::DiagonalString(d); } |
| |
| std::string ToVlogString(blas::Side s) { return blas::SideString(s); } |
| |
| std::string ToVlogString(blas::ComputationType ty) { |
| return blas::ComputationTypeString(ty); |
| } |
| |
// Formats a raw pointer for VLOG output, or "null" for nullptr.
// absl::StrCat does not accept pointers, so iostream insertion is used.
std::string ToVlogString(const void *ptr) {
  if (ptr == nullptr) {
    return "null";
  }
  std::ostringstream formatted;
  formatted << ptr;
  return formatted.str();
}
| |
// Formats a std::complex for VLOG output. absl::StrCat has no overload for
// std::complex; the standard stream inserter prints "(real,imag)".
template <class T>
std::string ToVlogString(const std::complex<T> &c) {
  std::ostringstream formatted;
  formatted << c;
  return formatted.str();
}
| |
// A std::function's target cannot be printed; report only whether one is set.
template <class T>
std::string ToVlogString(const std::function<T> &f) {
  if (f == nullptr) {
    return "null";
  }
  return "<non-null function>";
}
| |
| std::string ToVlogString(const DeviceMemoryBase &memory) { |
| return ToVlogString(memory.opaque()); |
| } |
| |
| std::string ToVlogString(const DeviceMemoryBase *memory) { |
| return memory == nullptr ? "null" : ToVlogString(*memory); |
| } |
| |
| std::string ToVlogString(const Eigen::half &h) { |
| return absl::StrCat(static_cast<float>(h)); |
| } |
| |
| std::string ToVlogString(int i) { return absl::StrCat(i); } |
| |
| std::string ToVlogString(uint32 i) { return absl::StrCat(i); } |
| |
| std::string ToVlogString(uint64 i) { return absl::StrCat(i); } |
| |
| std::string ToVlogString(int64 i) { return absl::StrCat(i); } |
| |
| std::string ToVlogString(float f) { return absl::StrCat(f); } |
| |
| std::string ToVlogString(double d) { return absl::StrCat(d); } |
| |
| template <typename T> |
| std::string ToVlogString(const HostOrDeviceScalar<T> &memory_or_constant) { |
| if (memory_or_constant.is_pointer()) { |
| return ToVlogString(memory_or_constant.pointer()); |
| } |
| return ToVlogString(memory_or_constant.value()); |
| } |
| |
| template <class T> |
| std::string ToVlogString(port::ArraySlice<T> elements) { |
| std::string str = absl::StrCat( |
| ToVlogString(reinterpret_cast<const void *>(elements.data())), "[", |
| elements.size(), "]{"); |
| const char *separator = ""; |
| size_t max_to_show = std::numeric_limits<size_t>::max(); |
| if (!VLOG_IS_ON(2)) { |
| max_to_show = 5; |
| } else if (!VLOG_IS_ON(3)) { |
| max_to_show = 20; |
| } else if (!VLOG_IS_ON(11)) { |
| max_to_show = 1000; |
| } |
| for (size_t i = 0; i < elements.size(); ++i) { |
| if (i == max_to_show) { |
| str += ", ..."; |
| break; |
| } |
| absl::StrAppend(&str, separator, ToVlogString(elements[i])); |
| separator = ", "; |
| } |
| str += "}"; |
| return str; |
| } |
| |
| template <class T> |
| std::string ToVlogString(port::MutableArraySlice<T> elements) { |
| return ToVlogString(port::ArraySlice<T>(elements)); |
| } |
| |
| std::string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) { |
| switch (depth_to_space_layout) { |
| case dnn::DepthToSpaceLayout::DepthHeightWidth: |
| return "DepthToSpaceLayout::DepthHeightWidth"; |
| } |
| return "unknown DepthToSpaceLayout"; |
| } |
| |
| std::string ToVlogString(dnn::DataType data_type) { |
| switch (data_type) { |
| case dnn::DataType::kFloat: |
| return "dnn::DataType::kFloat"; |
| case dnn::DataType::kDouble: |
| return "dnn::DataType::kDouble"; |
| case dnn::DataType::kHalf: |
| return "dnn::DataType::kHalf"; |
| case dnn::DataType::kInt8: |
| return "dnn::DataType::kInt8"; |
| case dnn::DataType::kInt32: |
| return "dnn::DataType::kInt32"; |
| default: |
| return "unknown DataType"; |
| } |
| } |
| |
| // Used together with PARAM to VLOG calls made to the stream. Intended |
| // to be used like this: |
| // |
| // VLOG(1) << CallStr("MyFunction", this, {PARAM(a), PARAM(b)}); |
| // |
| // where a and b are the parameters to MyFunction. |
| // |
| // See VLOG_CALL for a short-hand for this. This way of doing it saves |
| // a tremendous amount of boilerplate code given how many functions |
| // there are on Stream and how many parameters they each have. |
| std::string CallStr(const char *function_name, Stream *stream, |
| std::vector<std::pair<const char *, std::string>> params) { |
| // Do not call this function unless VLOG is on since just |
| // constructing all the strings in params is expensive. |
| CHECK(VLOG_IS_ON(1)); |
| |
| std::string str = absl::StrCat(stream->DebugStreamPointers(), |
| " Called Stream::", function_name, "("); |
| const char *separator = ""; |
| for (const auto ¶m : params) { |
| absl::StrAppend(&str, separator, param.first, "=", param.second); |
| separator = ", "; |
| } |
| absl::StrAppend(&str, ")"); |
| if (VLOG_IS_ON(10)) { |
| absl::StrAppend(&str, " ", port::CurrentStackTrace(), "\n"); |
| } |
| return str; |
| } |
| |
| // Use this macro to avoid having to type every parameter twice to log |
| // it with VLOG and CallStr. |
| #define PARAM(parameter) \ |
| { #parameter, ToVlogString(parameter) } |
| |
| // Use this macro to avoid having to type out the name of each |
| // function and to save some boilerplate. Intended to be used like this: |
| // |
| // VLOG_CALL(PARAM(a), PARAM(b)) |
| // |
| // This saves a tremendous amount of boilerplate compared to the alternative: |
| // |
| // VLOG(1) << "Calling MyFunction(a=" << ToVlogString(a) |
| // << ", b=" << ToVlogString(b); |
| // |
| // Note here that most of the parameter names are not short and that |
| // most of the functions take many more than 2 parameters. |
| #define VLOG_CALL(...) VLOG(1) << CallStr(__func__, this, {__VA_ARGS__}) |
| |
| } // namespace |
| |
// Constructs a stream whose platform-specific implementation is obtained
// from the parent executor. The stream starts unallocated and with a
// non-OK status; Init() must succeed before the stream is usable.
Stream::Stream(StreamExecutor *parent)
    : parent_(parent),
      implementation_(parent->implementation()->GetStreamImplementation()),
      allocated_(false),
      status_(port::InternalError("Uninitialized stream")),
      temporary_memory_manager_(this) {
  VLOG_CALL(PARAM(parent));
}

// As above, but adopts a caller-provided StreamInterface implementation
// instead of asking the executor for one.
Stream::Stream(StreamExecutor *parent,
               internal::StreamInterface *implementation)
    : parent_(parent),
      implementation_(implementation),
      allocated_(false),
      status_(port::InternalError("Uninitialized stream")),
      temporary_memory_manager_(this) {
  VLOG_CALL(PARAM(parent), PARAM(implementation));
}
| |
// Drains the stream before teardown: blocks until enqueued work finishes,
// releases temporary allocations, runs deferred callbacks, and finally
// returns the platform stream to the executor. The order matters —
// temporary memory and callbacks must not be released while work that may
// use them is still in flight.
Stream::~Stream() {
  VLOG_CALL();

  // Ensure the stream is completed.
  auto status = BlockHostUntilDone();
  if (!status.ok()) {
    LOG(WARNING) << "Error blocking host until done in stream destructor: "
                 << status;
  }
  temporary_memory_manager_.ForceDeallocateAll();
  RunAfterBlockHostUntilDoneCallbacks();

  // Only streams that were successfully registered in Init() need to be
  // deallocated with the parent executor.
  if (allocated_) {
    parent_->DeallocateStream(this);
  }
}
| |
// Polls the executor for this stream's current status and folds it into
// the stream's cached status via CheckStatus().
// NOTE(review): the UNIMPLEMENTED case is detected by comparing against an
// exact status message, which silently breaks if the executor's wording
// changes — confirm it stays in sync with StreamExecutor::GetStatus.
port::Status Stream::RefreshStatus() {
  port::Status status = parent_->GetStatus(this);
  // We should not put the stream in an error state, just because the GetStatus
  // method is unimplemented.
  if (status != port::Status(port::error::UNIMPLEMENTED,
                             "GetStatus is not supported on this executor.")) {
    CheckStatus(status);
  }
  return status;
}
| |
| Stream &Stream::Init() { |
| VLOG_CALL(); |
| |
| absl::MutexLock lock(&mu_); |
| CHECK_EQ(false, allocated_) |
| << "stream appears to already have been initialized"; |
| CHECK(!status_.ok()) << "stream should be in !ok() state pre-initialization"; |
| |
| if (parent_->AllocateStream(this)) { |
| // Successful initialization! |
| allocated_ = true; |
| status_ = port::Status::OK(); |
| } else { |
| LOG(ERROR) << "failed to allocate stream during initialization"; |
| } |
| |
| return *this; |
| } |
| |
// Allocates the platform timer object for |timer| on this stream's
// executor; failure puts the stream into an error state via CheckError.
Stream &Stream::InitTimer(Timer *timer) {
  VLOG_CALL(PARAM(timer));

  CheckError(parent_->AllocateTimer(timer));
  return *this;
}

// Convenience wrapper: initializes the stream, then the timer.
Stream &Stream::InitWithTimer(Timer *timer) {
  VLOG_CALL(PARAM(timer));

  return Init().InitTimer(timer);
}
| |
// Enqueues recording of |event| on this stream. Deliberately does NOT mark
// the stream as bad on failure, since the Event itself may be the faulty
// party; the error is only logged.
Stream &Stream::ThenRecordEvent(Event *event) {
  VLOG_CALL(PARAM(event));

  port::Status status = parent_->RecordEvent(this, event);
  if (!status.ok()) {
    LOG(ERROR) << "Error recording event in stream: " << status.error_message()
               << "; not marking stream as bad, as the Event object may be "
               << "at fault. Monitor for further errors.";
  }

  return *this;
}
| |
// Enqueues a batch-normalization forward pass on float data, forwarding all
// arguments to the executor's DNN backend. Marks the stream as in error if
// no DNN support is available or the backend call fails.
// NOTE(review): presumably batch_mean/batch_var and saved_* are only written
// when is_training is true — confirm against DoBatchNormalizationForward.
Stream &Stream::ThenBatchNormalizationForward(
    const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
    const DeviceMemory<float> &offset,
    const DeviceMemory<float> &estimated_mean,
    const DeviceMemory<float> &estimated_variance,
    const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
    const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    const double exponential_average_factor,
    dnn::ActivationMode activation_mode, DeviceMemory<float> *y,
    DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
    DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
    bool is_training,
    ScratchAllocator *reserve_space_allocator,
    ScratchAllocator *workspace_allocator) {
  VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
            PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    CheckError(dnn->DoBatchNormalizationForward(
        this, x, scale, offset, estimated_mean, estimated_variance, side_input,
        x_desc, scale_offset_desc, epsilon, exponential_average_factor,
        activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
        is_training, reserve_space_allocator, workspace_allocator));
  } else {
    SetErrorAndLogNoDnnSupport();
  }
  return *this;
}
| |
// Enqueues a batch-normalization backward pass on float data, producing
// gradients for the input, scale, and offset. Marks the stream as in error
// if no DNN support is available or the backend call fails.
Stream &Stream::ThenBatchNormalizationBackward(
    const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
    const DeviceMemory<float> &scale, const DeviceMemory<float> &mean,
    const DeviceMemory<float> &inv_var, const dnn::BatchDescriptor &x_desc,
    const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop,
    DeviceMemory<float> *offset_backprop,
    DeviceMemory<uint8> *reserve_space_data,
    ScratchAllocator *workspace_allocator) {
  VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
            PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
            PARAM(scale_backprop), PARAM(offset_backprop));
  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    CheckError(dnn->DoBatchNormalizationBackward(
        this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc,
        epsilon, x_backprop, scale_backprop, offset_backprop,
        reserve_space_data, workspace_allocator));
  } else {
    SetErrorAndLogNoDnnSupport();
  }
  return *this;
}
| |
// Half-precision variant of ThenBatchNormalizationForward: x and y are
// Eigen::half while the statistics and scale/offset remain float. Marks the
// stream as in error if no DNN support is available or the call fails.
Stream &Stream::ThenBatchNormalizationForward(
    const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
    const DeviceMemory<float> &offset,
    const DeviceMemory<float> &estimated_mean,
    const DeviceMemory<float> &estimated_variance,
    const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
    const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    const double exponential_average_factor,
    dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half> *y,
    DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
    DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
    bool is_training,
    ScratchAllocator *reserve_space_allocator,
    ScratchAllocator *workspace_allocator) {
  VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
            PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    CheckError(dnn->DoBatchNormalizationForward(
        this, x, scale, offset, estimated_mean, estimated_variance, side_input,
        x_desc, scale_offset_desc, epsilon, exponential_average_factor,
        activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
        is_training, reserve_space_allocator, workspace_allocator));
  } else {
    SetErrorAndLogNoDnnSupport();
  }
  return *this;
}
| |
// Half-precision variant of ThenBatchNormalizationBackward: activations and
// their gradients are Eigen::half, statistics and scale gradients are float.
// Marks the stream as in error if no DNN support is available or the call
// fails.
Stream &Stream::ThenBatchNormalizationBackward(
    const DeviceMemory<Eigen::half> &y_backprop,
    const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
    const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
    const dnn::BatchDescriptor &x_desc,
    const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    DeviceMemory<Eigen::half> *x_backprop, DeviceMemory<float> *scale_backprop,
    DeviceMemory<float> *offset_backprop,
    DeviceMemory<uint8> *reserve_space_data,
    ScratchAllocator *workspace_allocator) {
  VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
            PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
            PARAM(scale_backprop), PARAM(offset_backprop));
  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    CheckError(dnn->DoBatchNormalizationBackward(
        this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc,
        epsilon, x_backprop, scale_backprop, offset_backprop,
        reserve_space_data, workspace_allocator));

  } else {
    SetErrorAndLogNoDnnSupport();
  }
  return *this;
}
| |
| port::Status Stream::FusedConvolveWithAlgorithm( |
| const dnn::BatchDescriptor &conv_input_descriptor, |
| const DeviceMemory<double> &conv_input_data, double conv_input_scale, |
| const dnn::FilterDescriptor &filter_descriptor, |
| const DeviceMemory<double> &filter_data, |
| const dnn::ConvolutionDescriptor &convolution_descriptor, |
| const DeviceMemory<double> &side_input_data, double side_input_scale, |
| const dnn::BatchDescriptor &bias_descriptor, |
| const DeviceMemory<double> &biases, dnn::ActivationMode activation_mode, |
| const dnn::BatchDescriptor &output_descriptor, DeviceMemory<double> *output, |
| ScratchAllocator *scratch_allocator, |
| const dnn::AlgorithmConfig &algorithm_config, |
| dnn::ProfileResult *output_profile_result) { |
| VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data), |
| PARAM(conv_input_scale), PARAM(filter_descriptor), |
| PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases), |
| PARAM(side_input_data), PARAM(side_input_scale), |
| PARAM(activation_mode), PARAM(output_descriptor), PARAM(output), |
| PARAM(algorithm_config)); |
| |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| return dnn->DoFusedConvolve( |
| this, conv_input_descriptor, conv_input_data, conv_input_scale, |
| filter_descriptor, filter_data, convolution_descriptor, side_input_data, |
| side_input_scale, bias_descriptor, biases, activation_mode, |
| output_descriptor, output, scratch_allocator, algorithm_config, |
| output_profile_result); |
| } |
| return port::UnimplementedError("DNN library is not found."); |
| } |
| |
| port::Status Stream::FusedConvolveWithAlgorithm( |
| const dnn::BatchDescriptor &conv_input_descriptor, |
| const DeviceMemory<float> &conv_input_data, float conv_input_scale, |
| const dnn::FilterDescriptor &filter_descriptor, |
| const DeviceMemory<float> &filter_data, |
| const dnn::ConvolutionDescriptor &convolution_descriptor, |
| const DeviceMemory<float> &side_input_data, float side_input_scale, |
| const dnn::BatchDescriptor &bias_descriptor, |
| const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode, |
| const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output, |
| ScratchAllocator *scratch_allocator, |
| const dnn::AlgorithmConfig &algorithm_config, |
| dnn::ProfileResult *output_profile_result) { |
| VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data), |
| PARAM(conv_input_scale), PARAM(filter_descriptor), |
| PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases), |
| PARAM(side_input_data), PARAM(side_input_scale), |
| PARAM(activation_mode), PARAM(output_descriptor), PARAM(output), |
| PARAM(algorithm_config)); |
| |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| return dnn->DoFusedConvolve( |
| this, conv_input_descriptor, conv_input_data, conv_input_scale, |
| filter_descriptor, filter_data, convolution_descriptor, side_input_data, |
| side_input_scale, bias_descriptor, biases, activation_mode, |
| output_descriptor, output, scratch_allocator, algorithm_config, |
| output_profile_result); |
| } |
| return port::UnimplementedError("DNN library is not found."); |
| } |
| |
| port::Status Stream::FusedConvolveWithAlgorithm( |
| const dnn::BatchDescriptor &conv_input_descriptor, |
| const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale, |
| const dnn::FilterDescriptor &filter_descriptor, |
| const DeviceMemory<Eigen::half> &filter_data, |
| const dnn::ConvolutionDescriptor &convolution_descriptor, |
| const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale, |
| const dnn::BatchDescriptor &bias_descriptor, |
| const DeviceMemory<Eigen::half> &biases, |
| dnn::ActivationMode activation_mode, |
| const dnn::BatchDescriptor &output_descriptor, |
| DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator, |
| const dnn::AlgorithmConfig &algorithm_config, |
| dnn::ProfileResult *output_profile_result) { |
| VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data), |
| PARAM(conv_input_scale), PARAM(filter_descriptor), |
| PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases), |
| PARAM(side_input_data), PARAM(side_input_scale), |
| PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode), |
| PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config)); |
| |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| return dnn->DoFusedConvolve( |
| this, conv_input_descriptor, conv_input_data, conv_input_scale, |
| filter_descriptor, filter_data, convolution_descriptor, side_input_data, |
| side_input_scale, bias_descriptor, biases, activation_mode, |
| output_descriptor, output, scratch_allocator, algorithm_config, |
| output_profile_result); |
| } |
| return port::UnimplementedError("DNN library is not found."); |
| } |
| |
| port::Status Stream::FusedConvolveWithAlgorithm( |
| const dnn::BatchDescriptor &conv_input_descriptor, |
| const DeviceMemory<int8> &conv_input_data, float conv_input_scale, |
| const dnn::FilterDescriptor &filter_descriptor, |
| const DeviceMemory<int8> &filter_data, |
| const dnn::ConvolutionDescriptor &convolution_descriptor, |
| const DeviceMemory<int8> &side_input_data, float side_input_scale, |
| const dnn::BatchDescriptor &bias_descriptor, |
| const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode, |
| const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output, |
| ScratchAllocator *scratch_allocator, |
| const dnn::AlgorithmConfig &algorithm_config, |
| dnn::ProfileResult *output_profile_result) { |
| VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data), |
| PARAM(conv_input_scale), PARAM(filter_descriptor), |
| PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases), |
| PARAM(side_input_data), PARAM(side_input_scale), |
| PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode), |
| PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config)); |
| |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| return dnn->DoFusedConvolve( |
| this, conv_input_descriptor, conv_input_data, conv_input_scale, |
| filter_descriptor, filter_data, convolution_descriptor, side_input_data, |
| side_input_scale, bias_descriptor, biases, activation_mode, |
| output_descriptor, output, scratch_allocator, algorithm_config, |
| output_profile_result); |
| } |
| return port::UnimplementedError("DNN library is not found."); |
| } |
| |
| port::Status Stream::FusedConvolveWithAlgorithm( |
| const dnn::BatchDescriptor &conv_input_descriptor, |
| const DeviceMemory<int8> &conv_input_data, float conv_input_scale, |
| const dnn::FilterDescriptor &filter_descriptor, |
| const DeviceMemory<int8> &filter_data, |
| const dnn::ConvolutionDescriptor &convolution_descriptor, |
| const DeviceMemory<float> &side_input_data, float side_input_scale, |
| const dnn::BatchDescriptor &bias_descriptor, |
| const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode, |
| const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output, |
| ScratchAllocator *scratch_allocator, |
| const dnn::AlgorithmConfig &algorithm_config, |
| dnn::ProfileResult *output_profile_result) { |
| VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data), |
| PARAM(conv_input_scale), PARAM(filter_descriptor), |
| PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases), |
| PARAM(side_input_data), PARAM(side_input_scale), |
| PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode), |
| PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config)); |
| |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| return dnn->DoFusedConvolve( |
| this, conv_input_descriptor, conv_input_data, conv_input_scale, |
| filter_descriptor, filter_data, convolution_descriptor, side_input_data, |
| side_input_scale, bias_descriptor, biases, activation_mode, |
| output_descriptor, output, scratch_allocator, algorithm_config, |
| output_profile_result); |
| } |
| return port::UnimplementedError("DNN library is not found."); |
| } |
| |
// Legacy convolution entry point with default algorithm selection.
// NOTE(review): the detailed status from ConvolveWithAlgorithm is collapsed
// to a bool here, so the error message is lost — only the stream's error
// state records the failure.
Stream &Stream::ThenConvolve(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<float> &input_data,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<float> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<float> *output) {
  // Skip enqueueing entirely if the stream is already in an error state.
  if (ok()) {
    CheckError(ConvolveWithAlgorithm(
                   input_descriptor, input_data, filter_descriptor, filter_data,
                   convolution_descriptor, output_descriptor, output,
                   /*scratch_allocator=*/nullptr, dnn::AlgorithmConfig(),
                   /*output_profile_result=*/nullptr)
                   .ok());
  }
  return *this;
}
| |
| Stream &Stream::ThenConvolveQuantized( |
| const dnn::BatchDescriptor &input_descriptor, |
| const DeviceMemory<float> &input_data, |
| const dnn::FilterDescriptor &filter_descriptor, |
| const DeviceMemory<int8> &filter_coefficients, |
| const DeviceMemory<float> &coefficient_scales, |
| const dnn::ConvolutionDescriptor &convolution_descriptor, |
| const dnn::BatchDescriptor &output_descriptor, |
| DeviceMemory<float> *output) { |
| VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), |
| PARAM(filter_descriptor), PARAM(filter_coefficients), |
| PARAM(coefficient_scales), PARAM(convolution_descriptor), |
| PARAM(output_descriptor), PARAM(output)); |
| |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| CheckError(dnn->DoConvolveQuantized( |
| this, input_descriptor, input_data, filter_descriptor, |
| filter_coefficients, coefficient_scales, convolution_descriptor, |
| output_descriptor, output)); |
| } else { |
| SetError(); |
| LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor " |
| "without DNN support"; |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenConvolveQuantized( |
| const dnn::BatchDescriptor &input_descriptor, |
| const DeviceMemory<float> &input_data, |
| const dnn::FilterDescriptor &filter_descriptor, |
| const DeviceMemory<int16> &filter_coefficients, |
| const DeviceMemory<float> &coefficient_scales, |
| const dnn::ConvolutionDescriptor &convolution_descriptor, |
| const dnn::BatchDescriptor &output_descriptor, |
| DeviceMemory<float> *output) { |
| VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), |
| PARAM(filter_descriptor), PARAM(filter_coefficients), |
| PARAM(coefficient_scales), PARAM(convolution_descriptor), |
| PARAM(output_descriptor), PARAM(output)); |
| |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| CheckError(dnn->DoConvolveQuantized( |
| this, input_descriptor, input_data, filter_descriptor, |
| filter_coefficients, coefficient_scales, convolution_descriptor, |
| output_descriptor, output)); |
| } else { |
| SetError(); |
| LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor " |
| "without DNN support"; |
| } |
| return *this; |
| } |
| |
// Enqueues a depthwise-separable convolution described by two weight sets
// (first/second stage) and a depth multiplier. Marks the stream as in error
// if DNN support is missing or the call fails.
Stream &Stream::ThenSeparableConvolve(
    const dnn::BatchDescriptor &batch_descriptor,
    const DeviceMemory<float> &input_data,
    const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
    const DeviceMemory<float> &first_weights,
    const DeviceMemory<float> &second_weights,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<float> *output) {
  VLOG_CALL(
      PARAM(batch_descriptor), PARAM(input_data), PARAM(filter_descriptor),
      PARAM(depth_multiplier), PARAM(first_weights), PARAM(second_weights),
      PARAM(convolution_descriptor), PARAM(output_descriptor), PARAM(output));

  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    CheckError(dnn->DoSeparableConvolve(
        this, batch_descriptor, input_data, filter_descriptor, depth_multiplier,
        first_weights, second_weights, convolution_descriptor,
        output_descriptor, output));
  } else {
    SetErrorAndLogNoDnnSupport();
  }
  return *this;
}
| |
// Shared implementation for the typed ThenConvolveBackwardBias entry points
// below; enqueues the bias-gradient computation for a convolution. Marks the
// stream as in error if DNN support is missing or the call fails.
template <typename T>
Stream &Stream::ThenConvolveBackwardBiasImpl(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<T> &input_data,
    const dnn::BatchDescriptor &bias_descriptor,
    DeviceMemory<T> *backward_bias_data) {
  VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), PARAM(bias_descriptor),
            PARAM(backward_bias_data));

  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    CheckError(dnn->DoConvolveBackwardBias(this, input_descriptor, input_data,
                                           bias_descriptor,
                                           backward_bias_data));
  } else {
    SetErrorAndLogNoDnnSupport();
  }
  return *this;
}
| |
// Type-specific public entry points; each simply forwards to the shared
// ThenConvolveBackwardBiasImpl template.
Stream &Stream::ThenConvolveBackwardBias(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<double> &input_data,
    const dnn::BatchDescriptor &bias_descriptor,
    DeviceMemory<double> *backward_bias_data) {
  return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
                                      bias_descriptor, backward_bias_data);
}

Stream &Stream::ThenConvolveBackwardBias(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<float> &input_data,
    const dnn::BatchDescriptor &bias_descriptor,
    DeviceMemory<float> *backward_bias_data) {
  return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
                                      bias_descriptor, backward_bias_data);
}

Stream &Stream::ThenConvolveBackwardBias(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<Eigen::half> &input_data,
    const dnn::BatchDescriptor &bias_descriptor,
    DeviceMemory<Eigen::half> *backward_bias_data) {
  return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
                                      bias_descriptor, backward_bias_data);
}
| |
// Enqueues a matrix multiply of input_data by weights via the DNN backend.
// Marks the stream as in error if DNN support is missing or the call fails.
Stream &Stream::ThenMatMul(const DeviceMemory<float> &input_data,
                           const DeviceMemory<float> &weights,
                           const dnn::BatchDescriptor &input_dimensions,
                           const dnn::BatchDescriptor &output_dimensions,
                           DeviceMemory<float> *output_data) {
  VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(input_dimensions),
            PARAM(output_dimensions), PARAM(output_data));

  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    CheckError(dnn->DoMatMul(this, input_data, weights, input_dimensions,
                             output_dimensions, output_data));
  } else {
    SetErrorAndLogNoDnnSupport();
  }
  return *this;
}
| |
// Quantized matrix multiplies: the weights are int8 (first overload) or
// int16 (second), accompanied by per-weight scales.
// NOTE(review): weight_scales presumably rescales the quantized weights —
// confirm against DnnSupport::DoMatMulQuantized. Marks the stream as in
// error if DNN support is missing or the call fails.
Stream &Stream::ThenMatMulQuantized(
    const DeviceMemory<float> &input_data, const DeviceMemory<int8> &weights,
    const DeviceMemory<float> &weight_scales,
    const dnn::BatchDescriptor &input_dimensions,
    const dnn::BatchDescriptor &output_dimensions,
    DeviceMemory<float> *output_data) {
  VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
            PARAM(input_dimensions), PARAM(output_dimensions),
            PARAM(output_data));

  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    CheckError(dnn->DoMatMulQuantized(this, input_data, weights, weight_scales,
                                      input_dimensions, output_dimensions,
                                      output_data));
  } else {
    SetErrorAndLogNoDnnSupport();
  }
  return *this;
}

// int16-weight variant; see the int8 overload above.
Stream &Stream::ThenMatMulQuantized(
    const DeviceMemory<float> &input_data, const DeviceMemory<int16> &weights,
    const DeviceMemory<float> &weight_scales,
    const dnn::BatchDescriptor &input_dimensions,
    const dnn::BatchDescriptor &output_dimensions,
    DeviceMemory<float> *output_data) {
  VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
            PARAM(input_dimensions), PARAM(output_dimensions),
            PARAM(output_data));

  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    CheckError(dnn->DoMatMulQuantized(this, input_data, weights, weight_scales,
                                      input_dimensions, output_dimensions,
                                      output_data));
  } else {
    SetErrorAndLogNoDnnSupport();
  }
  return *this;
}
| |
// Enqueues a bias addition over input_data via the DNN backend. Marks the
// stream as in error if DNN support is missing or the call fails.
Stream &Stream::ThenBiasAdd(const DeviceMemory<float> &input_data,
                            const DeviceMemory<float> &biases,
                            const dnn::BatchDescriptor &dimensions,
                            DeviceMemory<float> *output_data) {
  VLOG_CALL(PARAM(input_data), PARAM(biases), PARAM(dimensions),
            PARAM(output_data));

  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    CheckError(
        dnn->DoBiasAdd(this, input_data, biases, dimensions, output_data));
  } else {
    SetErrorAndLogNoDnnSupport();
  }
  return *this;
}
| |
| Stream &Stream::ThenPoolForward( |
| const dnn::PoolingDescriptor &pooling_dimensions, |
| const dnn::BatchDescriptor &input_dimensions, |
| const DeviceMemory<double> &input_data, |
| const dnn::BatchDescriptor &output_dimensions, |
| DeviceMemory<double> *output_data, ScratchAllocator *workspace_allocator) { |
| VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), |
| PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), |
| PARAM(workspace_allocator)); |
| |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions, |
| input_data, output_dimensions, output_data, |
| workspace_allocator)); |
| } else { |
| SetError(); |
| LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor " |
| "without DNN support"; |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenPoolForward( |
| const dnn::PoolingDescriptor &pooling_dimensions, |
| const dnn::BatchDescriptor &input_dimensions, |
| const DeviceMemory<float> &input_data, |
| const dnn::BatchDescriptor &output_dimensions, |
| DeviceMemory<float> *output_data, ScratchAllocator *workspace_allocator) { |
| VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), |
| PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), |
| PARAM(workspace_allocator)); |
| |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions, |
| input_data, output_dimensions, output_data, |
| workspace_allocator)); |
| } else { |
| SetErrorAndLogNoDnnSupport(); |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenPoolForward( |
| const dnn::PoolingDescriptor &pooling_dimensions, |
| const dnn::BatchDescriptor &input_dimensions, |
| const DeviceMemory<Eigen::half> &input_data, |
| const dnn::BatchDescriptor &output_dimensions, |
| DeviceMemory<Eigen::half> *output_data, |
| ScratchAllocator *workspace_allocator) { |
| VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), |
| PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), |
| PARAM(workspace_allocator)); |
| |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions, |
| input_data, output_dimensions, output_data, |
| workspace_allocator)); |
| } else { |
| SetErrorAndLogNoDnnSupport(); |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenPoolForward( |
| const dnn::PoolingDescriptor &pooling_dimensions, |
| const dnn::BatchDescriptor &input_dimensions, |
| const DeviceMemory<int8> &input_data, |
| const dnn::BatchDescriptor &output_dimensions, |
| DeviceMemory<int8> *output_data, ScratchAllocator *workspace_allocator) { |
| VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), |
| PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), |
| PARAM(workspace_allocator)); |
| |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions, |
| input_data, output_dimensions, output_data, |
| workspace_allocator)); |
| } else { |
| SetErrorAndLogNoDnnSupport(); |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenPoolBackward( |
| const dnn::PoolingDescriptor &pooling_dimensions, |
| const dnn::BatchDescriptor &input_dimensions, |
| const DeviceMemory<double> &input_data, |
| const dnn::BatchDescriptor &output_dimensions, |
| const DeviceMemory<double> &output_data, |
| const DeviceMemory<double> &input_diff_data, |
| DeviceMemory<double> *output_diff_data, |
| ScratchAllocator *workspace_allocator) { |
| VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), |
| PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), |
| PARAM(input_diff_data), PARAM(output_diff_data), |
| PARAM(workspace_allocator)); |
| |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions, |
| input_data, output_dimensions, output_data, |
| input_diff_data, output_diff_data, |
| workspace_allocator)); |
| } else { |
| SetError(); |
| LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor " |
| "without DNN support"; |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenPoolBackward( |
| const dnn::PoolingDescriptor &pooling_dimensions, |
| const dnn::BatchDescriptor &input_dimensions, |
| const DeviceMemory<float> &input_data, |
| const dnn::BatchDescriptor &output_dimensions, |
| const DeviceMemory<float> &output_data, |
| const DeviceMemory<float> &input_diff_data, |
| DeviceMemory<float> *output_diff_data, |
| ScratchAllocator *workspace_allocator) { |
| VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), |
| PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), |
| PARAM(input_diff_data), PARAM(output_diff_data), |
| PARAM(workspace_allocator)); |
| |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions, |
| input_data, output_dimensions, output_data, |
| input_diff_data, output_diff_data, |
| workspace_allocator)); |
| } else { |
| SetErrorAndLogNoDnnSupport(); |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenPoolBackward( |
| const dnn::PoolingDescriptor &pooling_dimensions, |
| const dnn::BatchDescriptor &input_dimensions, |
| const DeviceMemory<Eigen::half> &input_data, |
| const dnn::BatchDescriptor &output_dimensions, |
| const DeviceMemory<Eigen::half> &output_data, |
| const DeviceMemory<Eigen::half> &input_diff_data, |
| DeviceMemory<Eigen::half> *output_diff_data, |
| ScratchAllocator *workspace_allocator) { |
| VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), |
| PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), |
| PARAM(input_diff_data), PARAM(output_diff_data), |
| PARAM(workspace_allocator)); |
| |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions, |
| input_data, output_dimensions, output_data, |
| input_diff_data, output_diff_data, |
| workspace_allocator)); |
| } else { |
| SetErrorAndLogNoDnnSupport(); |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenNormalizeWithDimensions( |
| const dnn::NormalizeDescriptor &normalize_descriptor, |
| const dnn::BatchDescriptor &dimensions, |
| const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data) { |
| VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(input_data), |
| PARAM(output_data)); |
| |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| CheckError(dnn->DoNormalizeWithDimensions( |
| this, normalize_descriptor, dimensions, input_data, output_data)); |
| } else { |
| SetErrorAndLogNoDnnSupport(); |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenNormalizeBackwardWithDimensions( |
| const dnn::NormalizeDescriptor &normalize_descriptor, |
| const dnn::BatchDescriptor &dimensions, const DeviceMemory<float> &raw_data, |
| const DeviceMemory<float> &normalized_data, |
| const DeviceMemory<float> &normalized_variable_gradient, |
| DeviceMemory<float> *raw_variable_gradient, |
| ScratchAllocator *workspace_allocator) { |
| VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(raw_data), |
| PARAM(normalized_data), PARAM(normalized_variable_gradient), |
| PARAM(raw_variable_gradient), PARAM(workspace_allocator)); |
| |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| CheckError(dnn->DoNormalizeBackwardWithDimensions( |
| this, normalize_descriptor, dimensions, raw_data, normalized_data, |
| normalized_variable_gradient, raw_variable_gradient, |
| workspace_allocator)); |
| } else { |
| SetErrorAndLogNoDnnSupport(); |
| } |
| return *this; |
| } |
| |
// Convenience wrapper: forwards to ThenActivateWithOptions with the default
// (zero) options bitfield.
Stream &Stream::ThenActivate(dnn::ActivationMode activation_mode,
                             const dnn::BatchDescriptor &dimensions,
                             const DeviceMemory<float> &input_data,
                             DeviceMemory<float> *output_data) {
  return ThenActivateWithOptions(activation_mode, dimensions, input_data,
                                 output_data, /*options=*/0);
}
| |
| Stream &Stream::ThenActivateWithOptions(dnn::ActivationMode activation_mode, |
| const dnn::BatchDescriptor &dimensions, |
| const DeviceMemory<float> &input_data, |
| DeviceMemory<float> *output_data, |
| uint64 options) { |
| VLOG_CALL(PARAM(activation_mode), PARAM(dimensions), PARAM(input_data), |
| PARAM(output_data), PARAM(options)); |
| |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| CheckError(dnn->DoActivate(this, activation_mode, dimensions, input_data, |
| output_data, options)); |
| } else { |
| SetErrorAndLogNoDnnSupport(); |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenDepthConcatenate( |
| port::ArraySlice<dnn::BatchDescriptor> input_dimensions, |
| port::ArraySlice<const DeviceMemory<float> *> input_data, |
| DeviceMemory<float> *output_data) { |
| VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data)); |
| |
| for (size_t i = 1; i < input_dimensions.size(); ++i) { |
| if (input_dimensions[i].count() != input_dimensions[0].count() || |
| input_dimensions[i].height() != input_dimensions[0].height() || |
| input_dimensions[i].width() != input_dimensions[0].width()) { |
| SetError(); |
| LOG(ERROR) << "Incompatible dimensions for depth concatenation.\n" |
| << "input_dimensions[0]: " << input_dimensions[0].ToString() |
| << "input_dimensions[" << i |
| << "]: " << input_dimensions[i].ToString(); |
| return *this; |
| } |
| } |
| |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| CheckError(dnn->DoDepthConcatenate(this, input_dimensions, input_data, |
| output_data)); |
| } else { |
| SetErrorAndLogNoDnnSupport(); |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenSpaceConcatenate( |
| port::ArraySlice<dnn::BatchDescriptor> input_dimensions, |
| port::ArraySlice<const DeviceMemory<float> *> input_data, |
| DeviceMemory<float> *output_data, |
| dnn::SpaceConcatenateMode concat_direction) { |
| VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data)); |
| |
| // Check that the input dimensions of all the other batches match those of the |
| // first batch. |
| for (size_t i = 1; i < input_dimensions.size(); ++i) { |
| if ((concat_direction == dnn::SpaceConcatenateMode::XDirection) && |
| (input_dimensions[i].count() != input_dimensions[0].count() || |
| input_dimensions[i].height() != input_dimensions[0].height() || |
| input_dimensions[i].feature_map_count() != |
| input_dimensions[0].feature_map_count())) { |
| SetError(); |
| LOG(ERROR) << "Incompatible dimensions for X concatenation.\n" |
| << "input_dimensions[0]: " << input_dimensions[0].ToString() |
| << "input_dimensions[" << i |
| << "]: " << input_dimensions[i].ToString(); |
| return *this; |
| } |
| |
| if ((concat_direction == dnn::SpaceConcatenateMode::YDirection) && |
| (input_dimensions[i].count() != input_dimensions[0].count() || |
| input_dimensions[i].width() != input_dimensions[0].width() || |
| input_dimensions[i].feature_map_count() != |
| input_dimensions[0].feature_map_count())) { |
| SetError(); |
| LOG(ERROR) << "Incompatible dimensions for Y concatenation.\n" |
| << "input_dimensions[0]: " << input_dimensions[0].ToString() |
| << "input_dimensions[" << i |
| << "]: " << input_dimensions[i].ToString(); |
| return *this; |
| } |
| } |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| CheckError(dnn->DoSpaceConcatenate(this, input_dimensions, input_data, |
| output_data, concat_direction)); |
| } else { |
| SetErrorAndLogNoDnnSupport(); |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenReshape(const dnn::BatchDescriptor &input_dimensions, |
| const DeviceMemory<float> &input_data, |
| const dnn::BatchDescriptor &output_dimensions, |
| DeviceMemory<float> *output_data) { |
| VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), |
| PARAM(output_dimensions), PARAM(output_data)); |
| |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| CheckError(dnn->DoReshape(this, input_dimensions, input_data, |
| output_dimensions, output_data)); |
| } else { |
| SetErrorAndLogNoDnnSupport(); |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenDepthToSpace( |
| const dnn::BatchDescriptor &input_dimensions, |
| const DeviceMemory<float> &input_data, |
| const dnn::DepthToSpaceLayout &depth_to_space_layout, |
| const int sqrt_depth_reduction, DeviceMemory<float> *output_data) { |
| VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), |
| PARAM(depth_to_space_layout), PARAM(sqrt_depth_reduction), |
| PARAM(output_data)); |
| |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| CheckError(dnn->DoDepthToSpace(this, input_dimensions, input_data, |
| depth_to_space_layout, sqrt_depth_reduction, |
| output_data)); |
| } else { |
| SetErrorAndLogNoDnnSupport(); |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenSpaceToDepth( |
| const dnn::BatchDescriptor &input_dimensions, |
| const DeviceMemory<float> &input_data, |
| const dnn::DepthToSpaceLayout &space_to_depth_layout, |
| const int sqrt_depth_increase, DeviceMemory<float> *output_data) { |
| VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), |
| PARAM(space_to_depth_layout), PARAM(sqrt_depth_increase), |
| PARAM(output_data)); |
| |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| CheckError(dnn->DoSpaceToDepth(this, input_dimensions, input_data, |
| space_to_depth_layout, sqrt_depth_increase, |
| output_data)); |
| } else { |
| SetErrorAndLogNoDnnSupport(); |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenElementwiseOperate( |
| dnn::ElementwiseOperation operation, |
| port::ArraySlice<dnn::BatchDescriptor> input_dimensions, |
| port::ArraySlice<const DeviceMemory<float> *> input_data, |
| const dnn::BatchDescriptor &output_dimensions, |
| DeviceMemory<float> *output_data) { |
| VLOG_CALL(PARAM(operation), PARAM(input_dimensions), PARAM(input_data), |
| PARAM(output_dimensions), PARAM(output_data)); |
| |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| CheckError(dnn->DoElementwiseOperate(this, operation, input_dimensions, |
| input_data, output_dimensions, |
| output_data)); |
| } else { |
| SetErrorAndLogNoDnnSupport(); |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenElementwiseOperateScaledQuantized( |
| dnn::ElementwiseOperation operation, |
| port::ArraySlice<int> input_multiplicands, int output_divisor, |
| port::ArraySlice<dnn::BatchDescriptor> input_dimensions, |
| port::ArraySlice<const DeviceMemory<float> *> input_data, |
| const dnn::BatchDescriptor &output_dimensions, |
| DeviceMemory<float> *output_data) { |
| VLOG_CALL(PARAM(operation), PARAM(input_multiplicands), PARAM(output_divisor), |
| PARAM(input_dimensions), PARAM(input_data), |
| PARAM(output_dimensions), PARAM(output_data)); |
| |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| CheckError(dnn->DoElementwiseOperateScaledQuantized( |
| this, operation, input_multiplicands, output_divisor, input_dimensions, |
| input_data, output_dimensions, output_data)); |
| } else { |
| SetErrorAndLogNoDnnSupport(); |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenXYPad(const dnn::BatchDescriptor &dimensions, |
| const DeviceMemory<float> &input_data, int64 left_pad, |
| int64 right_pad, int64 top_pad, int64 bottom_pad, |
| DeviceMemory<float> *output_data) { |
| VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_pad), |
| PARAM(right_pad), PARAM(top_pad), PARAM(bottom_pad), |
| PARAM(output_data)); |
| |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| CheckError(dnn->DoXYPad(this, dimensions, input_data, left_pad, right_pad, |
| top_pad, bottom_pad, output_data)); |
| } else { |
| SetErrorAndLogNoDnnSupport(); |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenXYSlice(const dnn::BatchDescriptor &dimensions, |
| const DeviceMemory<float> &input_data, |
| int64 left_trim, int64 right_trim, int64 top_trim, |
| int64 bottom_trim, |
| DeviceMemory<float> *output_data) { |
| VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_trim), |
| PARAM(right_trim), PARAM(top_trim), PARAM(bottom_trim), |
| PARAM(output_data)); |
| |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| CheckError(dnn->DoXYSlice(this, dimensions, input_data, left_trim, |
| right_trim, top_trim, bottom_trim, output_data)); |
| } else { |
| SetErrorAndLogNoDnnSupport(); |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenXYBroadcast(const dnn::BatchDescriptor &dimensions, |
| const DeviceMemory<float> &input_data, |
| int64 replicate_x, int64 replicate_y, |
| DeviceMemory<float> *output_data) { |
| VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(replicate_x), |
| PARAM(replicate_y), PARAM(output_data)); |
| |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| CheckError(dnn->DoXYBroadcast(this, dimensions, input_data, replicate_x, |
| replicate_y, output_data)); |
| } else { |
| SetErrorAndLogNoDnnSupport(); |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenMemcpyD2HQuantized( |
| const DeviceMemory<float> &gpu_unquantized_src, |
| dnn::QuantizedActivationMode mode, void *host_dst, uint64 size) { |
| VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(mode), PARAM(host_dst), |
| PARAM(size)); |
| |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| CheckError(dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, mode, |
| host_dst, size)); |
| } else { |
| SetErrorAndLogNoDnnSupport(); |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenMemcpyH2DQuantized( |
| const void *host_src, uint64 size, dnn::QuantizedActivationMode mode, |
| DeviceMemory<float> *gpu_unquantized_dst) { |
| VLOG_CALL(PARAM(host_src), PARAM(size), PARAM(mode), |
| PARAM(gpu_unquantized_dst)); |
| |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| CheckError(dnn->DoMemcpyH2DQuantized(this, host_src, size, mode, |
| gpu_unquantized_dst)); |
| } else { |
| SetErrorAndLogNoDnnSupport(); |
| } |
| return *this; |
| } |
| |
// Returns a sub-stream owned by this stream, reusing a cached ok stream when
// one is available and creating a new one otherwise. Broken (!ok) cached
// streams found along the way are pruned.
Stream *Stream::GetOrCreateSubStream() {
  // Do not destroy bad streams when holding mu_ because ~Stream() may
  // BlockHostUntilDone and its host callbacks might attempt to acquire mu_.
  // Moving them here defers destruction until after the lock is released.
  std::vector<std::unique_ptr<Stream>> bad_streams;

  absl::MutexLock lock(&mu_);

  // Look for the first reusable sub_stream that is ok, dropping !ok sub_streams
  // we encounter along the way.
  for (size_t index = 0; index < sub_streams_.size();) {
    std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
    if (pair.second) {
      // The sub_stream is reusable.
      Stream *sub_stream = pair.first.get();
      if (sub_stream->ok()) {
        VLOG(1) << DebugStreamPointers() << " reusing sub_stream "
                << sub_stream->DebugStreamPointers();
        // Mark as handed out so it is not offered to another caller until
        // ReturnSubStream is called.
        pair.second = false;
        return sub_stream;
      }

      // The stream is reusable and not ok. Streams have a monotonic state
      // machine; the stream will remain in !ok forever. Swap it with the last
      // stream and pop it off.
      const int64 last = sub_streams_.size() - 1;
      if (index != last) {
        std::swap(pair, sub_streams_[last]);
      }
      // Note: destruction happens after mu_ is released (see bad_streams
      // comment above). `index` is intentionally not advanced: the swapped-in
      // element now at `index` still needs inspection.
      bad_streams.push_back(std::move(sub_streams_.back().first));
      sub_streams_.pop_back();
      VLOG(1) << DebugStreamPointers() << " dropped !ok sub_stream "
              << sub_stream->DebugStreamPointers();
    } else {
      // The sub_stream is not reusable, move on to the next one.
      ++index;
    }
  }

  // No streams are reusable; create a new stream.
  sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}},
                            false);
  Stream *sub_stream = sub_streams_.back().first.get();
  sub_stream->Init();
  if (!sub_stream->ok()) {
    LOG(ERROR) << "sub-stream failed to be initialized";
  }
  VLOG(1) << DebugStreamPointers() << " created new sub_stream "
          << sub_stream->DebugStreamPointers();

  return sub_stream;
}
| |
// Returns a sub-stream previously handed out by GetOrCreateSubStream. An ok
// stream is marked reusable; a !ok stream is removed from the cache and
// destroyed. LOG(FATAL)s if `sub_stream` was not created by this stream.
void Stream::ReturnSubStream(Stream *sub_stream) {
  // Do not destroy bad streams when holding mu_ because ~Stream() may
  // BlockHostUntilDone and its host callbacks might attempt to acquire mu_.
  // The unique_ptr below keeps any removed stream alive until after unlock.
  std::unique_ptr<Stream> bad_stream;

  absl::MutexLock lock(&mu_);

  // Look for the sub-stream.
  for (int64 index = 0, end = sub_streams_.size(); index < end; ++index) {
    std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
    if (pair.first.get() != sub_stream) {
      continue;
    }

    // Found the sub_stream.
    if (sub_stream->ok()) {
      VLOG(1) << DebugStreamPointers() << " returned ok sub_stream "
              << sub_stream->DebugStreamPointers();
      // Mark reusable so a later GetOrCreateSubStream can hand it out again.
      pair.second = true;
    } else {
      // The returned stream is not ok. Streams have a monotonic state
      // machine; the stream will remain in !ok forever. Swap it with the last
      // stream and pop it off.
      VLOG(1) << DebugStreamPointers() << " returned !ok sub_stream "
              << sub_stream->DebugStreamPointers();
      const int64 last = sub_streams_.size() - 1;
      if (index != last) {
        std::swap(pair, sub_streams_[last]);
      }
      std::swap(bad_stream, sub_streams_.back().first);
      sub_streams_.pop_back();
    }
    return;
  }

  LOG(FATAL) << DebugStreamPointers()
             << " did not create the returned sub-stream "
             << sub_stream->DebugStreamPointers();
}
| |
// Enqueues the start of timer `t` on this stream; a failure to start is
// recorded via CheckError and marks the stream as in-error.
Stream &Stream::ThenStartTimer(Timer *t) {
  VLOG_CALL(PARAM(t));

  CheckError(parent_->StartTimer(this, t));
  return *this;
}
| |
// Enqueues the stop of timer `t` on this stream; a failure to stop is
// recorded via CheckError and marks the stream as in-error.
Stream &Stream::ThenStopTimer(Timer *t) {
  VLOG_CALL(PARAM(t));

  CheckError(parent_->StopTimer(this, t));
  return *this;
}
| |
| Stream &Stream::ThenWaitFor(Stream *other) { |
| VLOG_CALL(PARAM(other)); |
| |
| CHECK(this != other) << "stream cannot wait for itself"; |
| if (ok() && other->ok()) { |
| CheckError(parent_->CreateStreamDependency(this, other)); |
| } else { |
| SetError(); |
| LOG(INFO) << DebugStreamPointers() << " did not wait for " |
| << other->DebugStreamPointers(); |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenWaitFor(Event *event) { |
| VLOG_CALL(PARAM(event)); |
| |
| if (ok()) { |
| port::Status status = parent_->WaitForEvent(this, event); |
| if (!status.ok()) { |
| LOG(ERROR) << "Error waiting for event in stream: " |
| << status.error_message() |
| << "; not marking stream as bad, as the Event object may be " |
| << "at fault. Monitor for further errors."; |
| } |
| } else { |
| LOG(INFO) << DebugStreamPointers() << " did not wait for an event."; |
| } |
| return *this; |
| } |
| |
// A functor that implements ThenBlasXXX interfaces, which calls DoBlasXXX
// functions and logs for errors.
//
// Args must exactly match the parameter pack of the DoBlasXXX member
// function being invoked (minus the leading Stream*), since the member
// function pointer type is spelled in terms of Args.
template <typename... Args>
struct ThenBlasImpl {
  // blas_func is the DoBlasXXX member function pointer, and args are its
  // arguments except the first one of Stream* type.
  Stream &operator()(Stream *stream,
                     bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
                     Args... args) {
    return Run(stream, blas_func, /*record_error=*/true, args...);
  }

  // Like operator(), but only calls stream->CheckError() if record_error is
  // true.
  Stream &Run(Stream *stream,
              bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
              bool record_error, Args... args);
};
| |
// Dispatches blas_func on the stream's BLAS plugin. Skips the call entirely
// when the stream is already in an error state, and (when record_error is
// true) folds the call's success bit into the stream's error state.
template <typename... Args>
Stream &ThenBlasImpl<Args...>::Run(
    Stream *stream, bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
    bool record_error, Args... args) {
  if (stream->ok()) {
    bool ok;
    if (blas::BlasSupport *blas = stream->parent_->AsBlas()) {
      ok = (blas->*blas_func)(stream, args...);
    } else {
      // No BLAS plugin is registered for this StreamExecutor platform.
      LOG(WARNING)
          << "attempting to perform BLAS operation using StreamExecutor "
             "without BLAS support";
      ok = false;
    }
    if (record_error) {
      stream->CheckError(ok);
    }
  }
  return *stream;
}
| |
| Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<float> &x, |
| int incx, DeviceMemory<float> *result) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); |
| |
| ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx, |
| result); |
| } |
| |
| Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<double> &x, |
| int incx, DeviceMemory<double> *result) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); |
| |
| ThenBlasImpl<uint64, const DeviceMemory<double> &, int, |
| DeviceMemory<double> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx, |
| result); |
| } |
| |
| Stream &Stream::ThenBlasAsum(uint64 elem_count, |
| const DeviceMemory<std::complex<float>> &x, |
| int incx, DeviceMemory<float> *result) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); |
| |
| ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int, |
| DeviceMemory<float> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx, |
| result); |
| } |
| |
| Stream &Stream::ThenBlasAsum(uint64 elem_count, |
| const DeviceMemory<std::complex<double>> &x, |
| int incx, DeviceMemory<double> *result) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); |
| |
| ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int, |
| DeviceMemory<double> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx, |
| result); |
| } |
| |
| Stream &Stream::ThenBlasAxpy(uint64 elem_count, float alpha, |
| const DeviceMemory<float> &x, int incx, |
| DeviceMemory<float> *y, int incy) { |
| VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y), |
| PARAM(incy)); |
| |
| ThenBlasImpl<uint64, float, const DeviceMemory<float> &, int, |
| DeviceMemory<float> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx, |
| y, incy); |
| } |
| |
| Stream &Stream::ThenBlasAxpy(uint64 elem_count, double alpha, |
| const DeviceMemory<double> &x, int incx, |
| DeviceMemory<double> *y, int incy) { |
| VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y), |
| PARAM(incy)); |
| |
| ThenBlasImpl<uint64, double, const DeviceMemory<double> &, int, |
| DeviceMemory<double> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx, |
| y, incy); |
| } |
| |
| Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &x, |
| int incx, DeviceMemory<std::complex<float>> *y, |
| int incy) { |
| VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y), |
| PARAM(incy)); |
| |
| ThenBlasImpl<uint64, std::complex<float>, |
| const DeviceMemory<std::complex<float>> &, int, |
| DeviceMemory<std::complex<float>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx, |
| y, incy); |
| } |
| |
| Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &x, |
| int incx, DeviceMemory<std::complex<double>> *y, |
| int incy) { |
| VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y), |
| PARAM(incy)); |
| |
| ThenBlasImpl<uint64, std::complex<double>, |
| const DeviceMemory<std::complex<double>> &, int, |
| DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx, |
| y, incy); |
| } |
| |
| Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<float> &x, |
| int incx, DeviceMemory<float> *y, int incy) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy)); |
| |
| ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *, |
| int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y, |
| incy); |
| } |
| |
| Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<double> &x, |
| int incx, DeviceMemory<double> *y, int incy) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy)); |
| |
| ThenBlasImpl<uint64, const DeviceMemory<double> &, int, |
| DeviceMemory<double> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y, |
| incy); |
| } |
| |
| Stream &Stream::ThenBlasCopy(uint64 elem_count, |
| const DeviceMemory<std::complex<float>> &x, |
| int incx, DeviceMemory<std::complex<float>> *y, |
| int incy) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy)); |
| |
| ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int, |
| DeviceMemory<std::complex<float>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y, |
| incy); |
| } |
| |
| Stream &Stream::ThenBlasCopy(uint64 elem_count, |
| const DeviceMemory<std::complex<double>> &x, |
| int incx, DeviceMemory<std::complex<double>> *y, |
| int incy) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy)); |
| |
| ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int, |
| DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y, |
| incy); |
| } |
| |
| Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<float> &x, |
| int incx, const DeviceMemory<float> &y, int incy, |
| DeviceMemory<float> *result) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy), |
| PARAM(result)); |
| |
| ThenBlasImpl<uint64, const DeviceMemory<float> &, int, |
| const DeviceMemory<float> &, int, DeviceMemory<float> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy, |
| result); |
| } |
| |
| Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<double> &x, |
| int incx, const DeviceMemory<double> &y, int incy, |
| DeviceMemory<double> *result) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy), |
| PARAM(result)); |
| |
| ThenBlasImpl<uint64, const DeviceMemory<double> &, int, |
| const DeviceMemory<double> &, int, DeviceMemory<double> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy, |
| result); |
| } |
| |
| Stream &Stream::ThenBlasDotc(uint64 elem_count, |
| const DeviceMemory<std::complex<float>> &x, |
| int incx, |
| const DeviceMemory<std::complex<float>> &y, |
| int incy, |
| DeviceMemory<std::complex<float>> *result) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy), |
| PARAM(result)); |
| |
| ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int, |
| const DeviceMemory<std::complex<float>> &, int, |
| DeviceMemory<std::complex<float>> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y, |
| incy, result); |
| } |
| |
| Stream &Stream::ThenBlasDotc(uint64 elem_count, |
| const DeviceMemory<std::complex<double>> &x, |
| int incx, |
| const DeviceMemory<std::complex<double>> &y, |
| int incy, |
| DeviceMemory<std::complex<double>> *result) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy), |
| PARAM(result)); |
| |
| ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int, |
| const DeviceMemory<std::complex<double>> &, int, |
| DeviceMemory<std::complex<double>> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y, |
| incy, result); |
| } |
| |
| Stream &Stream::ThenBlasDotu(uint64 elem_count, |
| const DeviceMemory<std::complex<float>> &x, |
| int incx, |
| const DeviceMemory<std::complex<float>> &y, |
| int incy, |
| DeviceMemory<std::complex<float>> *result) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy), |
| PARAM(result)); |
| |
| ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int, |
| const DeviceMemory<std::complex<float>> &, int, |
| DeviceMemory<std::complex<float>> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y, |
| incy, result); |
| } |
| |
| Stream &Stream::ThenBlasDotu(uint64 elem_count, |
| const DeviceMemory<std::complex<double>> &x, |
| int incx, |
| const DeviceMemory<std::complex<double>> &y, |
| int incy, |
| DeviceMemory<std::complex<double>> *result) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy), |
| PARAM(result)); |
| |
| ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int, |
| const DeviceMemory<std::complex<double>> &, int, |
| DeviceMemory<std::complex<double>> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y, |
| incy, result); |
| } |
| |
| Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<float> &x, |
| int incx, DeviceMemory<float> *result) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); |
| |
| ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx, |
| result); |
| } |
| |
| Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<double> &x, |
| int incx, DeviceMemory<double> *result) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); |
| |
| ThenBlasImpl<uint64, const DeviceMemory<double> &, int, |
| DeviceMemory<double> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx, |
| result); |
| } |
| |
| Stream &Stream::ThenBlasNrm2(uint64 elem_count, |
| const DeviceMemory<std::complex<float>> &x, |
| int incx, DeviceMemory<float> *result) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); |
| |
| ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int, |
| DeviceMemory<float> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx, |
| result); |
| } |
| |
| Stream &Stream::ThenBlasNrm2(uint64 elem_count, |
| const DeviceMemory<std::complex<double>> &x, |
| int incx, DeviceMemory<double> *result) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); |
| |
| ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int, |
| DeviceMemory<double> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx, |
| result); |
| } |
| |
| Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<float> *x, int incx, |
| DeviceMemory<float> *y, int incy, float c, |
| float s) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy), |
| PARAM(c), PARAM(s)); |
| |
| ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int, |
| float, float> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy, |
| c, s); |
| } |
| |
| Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<double> *x, |
| int incx, DeviceMemory<double> *y, int incy, |
| double c, double s) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy), |
| PARAM(c), PARAM(s)); |
| |
| ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int, |
| double, double> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy, |
| c, s); |
| } |
| |
| Stream &Stream::ThenBlasRot(uint64 elem_count, |
| DeviceMemory<std::complex<float>> *x, int incx, |
| DeviceMemory<std::complex<float>> *y, int incy, |
| float c, float s) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy), |
| PARAM(c), PARAM(s)); |
| |
| ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int, |
| DeviceMemory<std::complex<float>> *, int, float, float> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy, |
| c, s); |
| } |
| |
| Stream &Stream::ThenBlasRot(uint64 elem_count, |
| DeviceMemory<std::complex<double>> *x, int incx, |
| DeviceMemory<std::complex<double>> *y, int incy, |
| double c, double s) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy), |
| PARAM(c), PARAM(s)); |
| |
| ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int, |
| DeviceMemory<std::complex<double>> *, int, double, double> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy, |
| c, s); |
| } |
| |
| Stream &Stream::ThenBlasRotg(DeviceMemory<float> *a, DeviceMemory<float> *b, |
| DeviceMemory<float> *c, DeviceMemory<float> *s) { |
| VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s)); |
| |
| ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *, |
| DeviceMemory<float> *, DeviceMemory<float> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s); |
| } |
| |
| Stream &Stream::ThenBlasRotg(DeviceMemory<double> *a, DeviceMemory<double> *b, |
| DeviceMemory<double> *c, DeviceMemory<double> *s) { |
| VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s)); |
| |
| ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *, |
| DeviceMemory<double> *, DeviceMemory<double> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s); |
| } |
| |
| Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<float>> *a, |
| DeviceMemory<std::complex<float>> *b, |
| DeviceMemory<float> *c, |
| DeviceMemory<std::complex<float>> *s) { |
| VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s)); |
| |
| ThenBlasImpl<DeviceMemory<std::complex<float>> *, |
| DeviceMemory<std::complex<float>> *, DeviceMemory<float> *, |
| DeviceMemory<std::complex<float>> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s); |
| } |
| |
| Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<double>> *a, |
| DeviceMemory<std::complex<double>> *b, |
| DeviceMemory<double> *c, |
| DeviceMemory<std::complex<double>> *s) { |
| VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s)); |
| |
| ThenBlasImpl<DeviceMemory<std::complex<double>> *, |
| DeviceMemory<std::complex<double>> *, DeviceMemory<double> *, |
| DeviceMemory<std::complex<double>> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s); |
| } |
| |
| Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<float> *x, |
| int incx, DeviceMemory<float> *y, int incy, |
| const DeviceMemory<float> ¶m) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy), |
| PARAM(param)); |
| |
| ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int, |
| const DeviceMemory<float> &> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y, |
| incy, param); |
| } |
| |
| Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<double> *x, |
| int incx, DeviceMemory<double> *y, int incy, |
| const DeviceMemory<double> ¶m) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy), |
| PARAM(param)); |
| |
| ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int, |
| const DeviceMemory<double> &> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y, |
| incy, param); |
| } |
| |
| Stream &Stream::ThenBlasRotmg(DeviceMemory<float> *d1, DeviceMemory<float> *d2, |
| DeviceMemory<float> *x1, |
| const DeviceMemory<float> &y1, |
| DeviceMemory<float> *param) { |
| VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param)); |
| |
| ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *, |
| DeviceMemory<float> *, const DeviceMemory<float> &, |
| DeviceMemory<float> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param); |
| } |
| |
| Stream &Stream::ThenBlasRotmg(DeviceMemory<double> *d1, |
| DeviceMemory<double> *d2, |
| DeviceMemory<double> *x1, |
| const DeviceMemory<double> &y1, |
| DeviceMemory<double> *param) { |
| VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param)); |
| |
| ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *, |
| DeviceMemory<double> *, const DeviceMemory<double> &, |
| DeviceMemory<double> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param); |
| } |
| |
| Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha, |
| DeviceMemory<float> *x, int incx) { |
| VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<uint64, float, DeviceMemory<float> *, int> impl; |
| return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx); |
| } |
| |
| Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha, |
| DeviceMemory<double> *x, int incx) { |
| VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<uint64, double, DeviceMemory<double> *, int> impl; |
| return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx); |
| } |
| |
| Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha, |
| DeviceMemory<std::complex<float>> *x, int incx) { |
| VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<uint64, float, DeviceMemory<std::complex<float>> *, int> impl; |
| return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx); |
| } |
| |
| Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha, |
| DeviceMemory<std::complex<double>> *x, int incx) { |
| VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<uint64, double, DeviceMemory<std::complex<double>> *, int> impl; |
| return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx); |
| } |
| |
| Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<float> alpha, |
| DeviceMemory<std::complex<float>> *x, int incx) { |
| VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<uint64, std::complex<float>, DeviceMemory<std::complex<float>> *, |
| int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx); |
| } |
| |
| Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<double> alpha, |
| DeviceMemory<std::complex<double>> *x, int incx) { |
| VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<uint64, std::complex<double>, |
| DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx); |
| } |
| |
| Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<float> *x, |
| int incx, DeviceMemory<float> *y, int incy) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy)); |
| |
| ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y, |
| incy); |
| } |
| |
| Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<double> *x, |
| int incx, DeviceMemory<double> *y, int incy) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy)); |
| |
| ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y, |
| incy); |
| } |
| |
| Stream &Stream::ThenBlasSwap(uint64 elem_count, |
| DeviceMemory<std::complex<float>> *x, int incx, |
| DeviceMemory<std::complex<float>> *y, int incy) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy)); |
| |
| ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int, |
| DeviceMemory<std::complex<float>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y, |
| incy); |
| } |
| |
| Stream &Stream::ThenBlasSwap(uint64 elem_count, |
| DeviceMemory<std::complex<double>> *x, int incx, |
| DeviceMemory<std::complex<double>> *y, int incy) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy)); |
| |
| ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int, |
| DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y, |
| incy); |
| } |
| |
| Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<float> &x, |
| int incx, DeviceMemory<int> *result) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); |
| |
| ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx, |
| result); |
| } |
| |
| Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<double> &x, |
| int incx, DeviceMemory<int> *result) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); |
| |
| ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx, |
| result); |
| } |
| |
| Stream &Stream::ThenBlasIamax(uint64 elem_count, |
| const DeviceMemory<std::complex<float>> &x, |
| int incx, DeviceMemory<int> *result) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); |
| |
| ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int, |
| DeviceMemory<int> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx, |
| result); |
| } |
| |
| Stream &Stream::ThenBlasIamax(uint64 elem_count, |
| const DeviceMemory<std::complex<double>> &x, |
| int incx, DeviceMemory<int> *result) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); |
| |
| ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int, |
| DeviceMemory<int> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx, |
| result); |
| } |
| |
| Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<float> &x, |
| int incx, DeviceMemory<int> *result) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); |
| |
| ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx, |
| result); |
| } |
| |
| Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<double> &x, |
| int incx, DeviceMemory<int> *result) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); |
| |
| ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx, |
| result); |
| } |
| |
| Stream &Stream::ThenBlasIamin(uint64 elem_count, |
| const DeviceMemory<std::complex<float>> &x, |
| int incx, DeviceMemory<int> *result) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); |
| |
| ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int, |
| DeviceMemory<int> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx, |
| result); |
| } |
| |
| Stream &Stream::ThenBlasIamin(uint64 elem_count, |
| const DeviceMemory<std::complex<double>> &x, |
| int incx, DeviceMemory<int> *result) { |
| VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); |
| |
| ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int, |
| DeviceMemory<int> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx, |
| result); |
| } |
| |
| Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, |
| uint64 kl, uint64 ku, float alpha, |
| const DeviceMemory<float> &a, int lda, |
| const DeviceMemory<float> &x, int incx, float beta, |
| DeviceMemory<float> *y, int incy) { |
| VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx), |
| PARAM(beta), PARAM(y), PARAM(incy)); |
| |
| ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, float, |
| const DeviceMemory<float> &, int, const DeviceMemory<float> &, |
| int, float, DeviceMemory<float> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha, |
| a, lda, x, incx, beta, y, incy); |
| } |
| |
| Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, |
| uint64 kl, uint64 ku, double alpha, |
| const DeviceMemory<double> &a, int lda, |
| const DeviceMemory<double> &x, int incx, |
| double beta, DeviceMemory<double> *y, int incy) { |
| VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx), |
| PARAM(beta), PARAM(y), PARAM(incy)); |
| |
| ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, double, |
| const DeviceMemory<double> &, int, const DeviceMemory<double> &, |
| int, double, DeviceMemory<double> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha, |
| a, lda, x, incx, beta, y, incy); |
| } |
| |
| Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, |
| uint64 kl, uint64 ku, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, |
| int lda, |
| const DeviceMemory<std::complex<float>> &x, |
| int incx, std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *y, int incy) { |
| VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx), |
| PARAM(beta), PARAM(y), PARAM(incy)); |
| |
| ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, |
| std::complex<float>, const DeviceMemory<std::complex<float>> &, |
| int, const DeviceMemory<std::complex<float>> &, int, |
| std::complex<float>, DeviceMemory<std::complex<float>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha, |
| a, lda, x, incx, beta, y, incy); |
| } |
| |
| Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, |
| uint64 kl, uint64 ku, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, |
| int lda, |
| const DeviceMemory<std::complex<double>> &x, |
| int incx, std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *y, int incy) { |
| VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx), |
| PARAM(beta), PARAM(y), PARAM(incy)); |
| |
| ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, |
| std::complex<double>, const DeviceMemory<std::complex<double>> &, |
| int, const DeviceMemory<std::complex<double>> &, int, |
| std::complex<double>, DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha, |
| a, lda, x, incx, beta, y, incy); |
| } |
| |
| Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, |
| float alpha, const DeviceMemory<float> &a, int lda, |
| const DeviceMemory<float> &x, int incx, float beta, |
| DeviceMemory<float> *y, int incy) { |
| VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a), |
| PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), |
| PARAM(incy)); |
| |
| ThenBlasImpl<blas::Transpose, uint64, uint64, float, |
| const DeviceMemory<float> &, int, const DeviceMemory<float> &, |
| int, float, DeviceMemory<float> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda, |
| x, incx, beta, y, incy); |
| } |
| |
| Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, |
| double alpha, const DeviceMemory<double> &a, |
| int lda, const DeviceMemory<double> &x, int incx, |
| double beta, DeviceMemory<double> *y, int incy) { |
| VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a), |
| PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), |
| PARAM(incy)); |
| |
| ThenBlasImpl<blas::Transpose, uint64, uint64, double, |
| const DeviceMemory<double> &, int, const DeviceMemory<double> &, |
| int, double, DeviceMemory<double> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda, |
| x, incx, beta, y, incy); |
| } |
| |
| Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, |
| int lda, |
| const DeviceMemory<std::complex<float>> &x, |
| int incx, std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *y, int incy) { |
| VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a), |
| PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), |
| PARAM(incy)); |
| |
| ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<float>, |
| const DeviceMemory<std::complex<float>> &, int, |
| const DeviceMemory<std::complex<float>> &, int, |
| std::complex<float>, DeviceMemory<std::complex<float>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda, |
| x, incx, beta, y, incy); |
| } |
| |
| Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, |
| int lda, |
| const DeviceMemory<std::complex<double>> &x, |
| int incx, std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *y, int incy) { |
| VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a), |
| PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), |
| PARAM(incy)); |
| |
| ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<double>, |
| const DeviceMemory<std::complex<double>> &, int, |
| const DeviceMemory<std::complex<double>> &, int, |
| std::complex<double>, DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda, |
| x, incx, beta, y, incy); |
| } |
| |
| Stream &Stream::ThenBlasGer(uint64 m, uint64 n, float alpha, |
| const DeviceMemory<float> &x, int incx, |
| const DeviceMemory<float> &y, int incy, |
| DeviceMemory<float> *a, int lda) { |
| VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y), |
| PARAM(incy), PARAM(a), PARAM(lda)); |
| |
| ThenBlasImpl<uint64, uint64, float, const DeviceMemory<float> &, int, |
| const DeviceMemory<float> &, int, DeviceMemory<float> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y, |
| incy, a, lda); |
| } |
| |
| Stream &Stream::ThenBlasGer(uint64 m, uint64 n, double alpha, |
| const DeviceMemory<double> &x, int incx, |
| const DeviceMemory<double> &y, int incy, |
| DeviceMemory<double> *a, int lda) { |
| VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y), |
| PARAM(incy), PARAM(a), PARAM(lda)); |
| |
| ThenBlasImpl<uint64, uint64, double, const DeviceMemory<double> &, int, |
| const DeviceMemory<double> &, int, DeviceMemory<double> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y, |
| incy, a, lda); |
| } |
| |
| Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &x, |
| int incx, |
| const DeviceMemory<std::complex<float>> &y, |
| int incy, DeviceMemory<std::complex<float>> *a, |
| int lda) { |
| VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y), |
| PARAM(incy), PARAM(a), PARAM(lda)); |
| |
| ThenBlasImpl<uint64, uint64, std::complex<float>, |
| const DeviceMemory<std::complex<float>> &, int, |
| const DeviceMemory<std::complex<float>> &, int, |
| DeviceMemory<std::complex<float>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y, |
| incy, a, lda); |
| } |
| |
| Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &x, |
| int incx, |
| const DeviceMemory<std::complex<double>> &y, |
| int incy, DeviceMemory<std::complex<double>> *a, |
| int lda) { |
| VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y), |
| PARAM(incy), PARAM(a), PARAM(lda)); |
| |
| ThenBlasImpl<uint64, uint64, std::complex<double>, |
| const DeviceMemory<std::complex<double>> &, int, |
| const DeviceMemory<std::complex<double>> &, int, |
| DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y, |
| incy, a, lda); |
| } |
| |
| Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &x, |
| int incx, |
| const DeviceMemory<std::complex<float>> &y, |
| int incy, DeviceMemory<std::complex<float>> *a, |
| int lda) { |
| VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y), |
| PARAM(incy), PARAM(a), PARAM(lda)); |
| |
| ThenBlasImpl<uint64, uint64, std::complex<float>, |
| const DeviceMemory<std::complex<float>> &, int, |
| const DeviceMemory<std::complex<float>> &, int, |
| DeviceMemory<std::complex<float>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y, |
| incy, a, lda); |
| } |
| |
| Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &x, |
| int incx, |
| const DeviceMemory<std::complex<double>> &y, |
| int incy, DeviceMemory<std::complex<double>> *a, |
| int lda) { |
| VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y), |
| PARAM(incy), PARAM(a), PARAM(lda)); |
| |
| ThenBlasImpl<uint64, uint64, std::complex<double>, |
| const DeviceMemory<std::complex<double>> &, int, |
| const DeviceMemory<std::complex<double>> &, int, |
| DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y, |
| incy, a, lda); |
| } |
| |
| Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, |
| int lda, |
| const DeviceMemory<std::complex<float>> &x, |
| int incx, std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *y, int incy) { |
| VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), |
| PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy)); |
| |
| ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<float>, |
| const DeviceMemory<std::complex<float>> &, int, |
| const DeviceMemory<std::complex<float>> &, int, |
| std::complex<float>, DeviceMemory<std::complex<float>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda, |
| x, incx, beta, y, incy); |
| } |
| |
| Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, |
| int lda, |
| const DeviceMemory<std::complex<double>> &x, |
| int incx, std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *y, int incy) { |
| VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), |
| PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy)); |
| |
| ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<double>, |
| const DeviceMemory<std::complex<double>> &, int, |
| const DeviceMemory<std::complex<double>> &, int, |
| std::complex<double>, DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda, |
| x, incx, beta, y, incy); |
| } |
| |
| Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, |
| int lda, |
| const DeviceMemory<std::complex<float>> &x, |
| int incx, std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *y, int incy) { |
| VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), |
| PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy)); |
| |
| ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>, |
| const DeviceMemory<std::complex<float>> &, int, |
| const DeviceMemory<std::complex<float>> &, int, |
| std::complex<float>, DeviceMemory<std::complex<float>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x, |
| incx, beta, y, incy); |
| } |
| |
| Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, |
| int lda, |
| const DeviceMemory<std::complex<double>> &x, |
| int incx, std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *y, int incy) { |
| VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), |
| PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy)); |
| |
| ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>, |
| const DeviceMemory<std::complex<double>> &, int, |
| const DeviceMemory<std::complex<double>> &, int, |
| std::complex<double>, DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x, |
| incx, beta, y, incy); |
| } |
| |
| Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, float alpha, |
| const DeviceMemory<std::complex<float>> &x, |
| int incx, DeviceMemory<std::complex<float>> *a, |
| int lda) { |
| VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), |
| PARAM(a), PARAM(lda)); |
| |
| ThenBlasImpl<blas::UpperLower, uint64, float, |
| const DeviceMemory<std::complex<float>> &, int, |
| DeviceMemory<std::complex<float>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a, |
| lda); |
| } |
| |
| Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, double alpha, |
| const DeviceMemory<std::complex<double>> &x, |
| int incx, DeviceMemory<std::complex<double>> *a, |
| int lda) { |
| VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), |
| PARAM(a), PARAM(lda)); |
| |
| ThenBlasImpl<blas::UpperLower, uint64, double, |
| const DeviceMemory<std::complex<double>> &, int, |
| DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a, |
| lda); |
| } |
| |
| Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &x, |
| int incx, |
| const DeviceMemory<std::complex<float>> &y, |
| int incy, DeviceMemory<std::complex<float>> *a, |
| int lda) { |
| VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), |
| PARAM(y), PARAM(incy), PARAM(a), PARAM(lda)); |
| |
| ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>, |
| const DeviceMemory<std::complex<float>> &, int, |
| const DeviceMemory<std::complex<float>> &, int, |
| DeviceMemory<std::complex<float>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y, |
| incy, a, lda); |
| } |
| |
| Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &x, |
| int incx, |
| const DeviceMemory<std::complex<double>> &y, |
| int incy, DeviceMemory<std::complex<double>> *a, |
| int lda) { |
| VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), |
| PARAM(y), PARAM(incy), PARAM(a), PARAM(lda)); |
| |
| ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>, |
| const DeviceMemory<std::complex<double>> &, int, |
| const DeviceMemory<std::complex<double>> &, int, |
| DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y, |
| incy, a, lda); |
| } |
| |
| Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &ap, |
| const DeviceMemory<std::complex<float>> &x, |
| int incx, std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *y, int incy) { |
| VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x), |
| PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy)); |
| |
| ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>, |
| const DeviceMemory<std::complex<float>> &, |
| const DeviceMemory<std::complex<float>> &, int, |
| std::complex<float>, DeviceMemory<std::complex<float>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx, |
| beta, y, incy); |
| } |
| |
| Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &ap, |
| const DeviceMemory<std::complex<double>> &x, |
| int incx, std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *y, int incy) { |
| VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x), |
| PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy)); |
| |
| ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>, |
| const DeviceMemory<std::complex<double>> &, |
| const DeviceMemory<std::complex<double>> &, int, |
| std::complex<double>, DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx, |
| beta, y, incy); |
| } |
| |
| Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, float alpha, |
| const DeviceMemory<std::complex<float>> &x, |
| int incx, DeviceMemory<std::complex<float>> *ap) { |
| VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), |
| PARAM(ap)); |
| |
| ThenBlasImpl<blas::UpperLower, uint64, float, |
| const DeviceMemory<std::complex<float>> &, int, |
| DeviceMemory<std::complex<float>> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap); |
| } |
| |
| Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, double alpha, |
| const DeviceMemory<std::complex<double>> &x, |
| int incx, DeviceMemory<std::complex<double>> *ap) { |
| VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), |
| PARAM(ap)); |
| |
| ThenBlasImpl<blas::UpperLower, uint64, double, |
| const DeviceMemory<std::complex<double>> &, int, |
| DeviceMemory<std::complex<double>> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap); |
| } |
| |
| Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &x, |
| int incx, |
| const DeviceMemory<std::complex<float>> &y, |
| int incy, DeviceMemory<std::complex<float>> *ap) { |
| VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), |
| PARAM(y), PARAM(incy), PARAM(ap)); |
| |
| ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>, |
| const DeviceMemory<std::complex<float>> &, int, |
| const DeviceMemory<std::complex<float>> &, int, |
| DeviceMemory<std::complex<float>> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y, |
| incy, ap); |
| } |
| |
| Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &x, |
| int incx, |
| const DeviceMemory<std::complex<double>> &y, |
| int incy, DeviceMemory<std::complex<double>> *ap) { |
| VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), |
| PARAM(y), PARAM(incy), PARAM(ap)); |
| |
| ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>, |
| const DeviceMemory<std::complex<double>> &, int, |
| const DeviceMemory<std::complex<double>> &, int, |
| DeviceMemory<std::complex<double>> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y, |
| incy, ap); |
| } |
| |
| Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k, |
| float alpha, const DeviceMemory<float> &a, int lda, |
| const DeviceMemory<float> &x, int incx, float beta, |
| DeviceMemory<float> *y, int incy) { |
| VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), |
| PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy)); |
| |
| ThenBlasImpl<blas::UpperLower, uint64, uint64, float, |
| const DeviceMemory<float> &, int, const DeviceMemory<float> &, |
| int, float, DeviceMemory<float> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda, |
| x, incx, beta, y, incy); |
| } |
| |
| Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k, |
| double alpha, const DeviceMemory<double> &a, |
| int lda, const DeviceMemory<double> &x, int incx, |
| double beta, DeviceMemory<double> *y, int incy) { |
| VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), |
| PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy)); |
| |
| ThenBlasImpl<blas::UpperLower, uint64, uint64, double, |
| const DeviceMemory<double> &, int, const DeviceMemory<double> &, |
| int, double, DeviceMemory<double> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda, |
| x, incx, beta, y, incy); |
| } |
| |
| Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, float alpha, |
| const DeviceMemory<float> &ap, |
| const DeviceMemory<float> &x, int incx, float beta, |
| DeviceMemory<float> *y, int incy) { |
| VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x), |
| PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy)); |
| |
| ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &, |
| const DeviceMemory<float> &, int, float, DeviceMemory<float> *, |
| int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx, |
| beta, y, incy); |
| } |
| |
| Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, double alpha, |
| const DeviceMemory<double> &ap, |
| const DeviceMemory<double> &x, int incx, |
| double beta, DeviceMemory<double> *y, int incy) { |
| VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x), |
| PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy)); |
| |
| ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &, |
| const DeviceMemory<double> &, int, double, |
| DeviceMemory<double> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx, |
| beta, y, incy); |
| } |
| |
| Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, float alpha, |
| const DeviceMemory<float> &x, int incx, |
| DeviceMemory<float> *ap) { |
| VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), |
| PARAM(ap)); |
| |
| ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &, |
| int, DeviceMemory<float> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap); |
| } |
| |
| Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, double alpha, |
| const DeviceMemory<double> &x, int incx, |
| DeviceMemory<double> *ap) { |
| VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), |
| PARAM(ap)); |
| |
| ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &, |
| int, DeviceMemory<double> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap); |
| } |
| |
| Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, float alpha, |
| const DeviceMemory<float> &x, int incx, |
| const DeviceMemory<float> &y, int incy, |
| DeviceMemory<float> *ap) { |
| VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), |
| PARAM(y), PARAM(incy), PARAM(ap)); |
| |
| ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &, |
| int, const DeviceMemory<float> &, int, DeviceMemory<float> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y, |
| incy, ap); |
| } |
| |
| Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, double alpha, |
| const DeviceMemory<double> &x, int incx, |
| const DeviceMemory<double> &y, int incy, |
| DeviceMemory<double> *ap) { |
| VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), |
| PARAM(y), PARAM(incy), PARAM(ap)); |
| |
| ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &, |
| int, const DeviceMemory<double> &, int, DeviceMemory<double> *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y, |
| incy, ap); |
| } |
| |
| Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, float alpha, |
| const DeviceMemory<float> &a, int lda, |
| const DeviceMemory<float> &x, int incx, float beta, |
| DeviceMemory<float> *y, int incy) { |
| VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), |
| PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy)); |
| |
| ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &, |
| int, const DeviceMemory<float> &, int, float, |
| DeviceMemory<float> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x, |
| incx, beta, y, incy); |
| } |
| |
| Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, double alpha, |
| const DeviceMemory<double> &a, int lda, |
| const DeviceMemory<double> &x, int incx, |
| double beta, DeviceMemory<double> *y, int incy) { |
| VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), |
| PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy)); |
| |
| ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &, |
| int, const DeviceMemory<double> &, int, double, |
| DeviceMemory<double> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x, |
| incx, beta, y, incy); |
| } |
| |
| Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, float alpha, |
| const DeviceMemory<float> &x, int incx, |
| DeviceMemory<float> *a, int lda) { |
| VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), |
| PARAM(a), PARAM(lda)); |
| |
| ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &, |
| int, DeviceMemory<float> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a, |
| lda); |
| } |
| |
| Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, double alpha, |
| const DeviceMemory<double> &x, int incx, |
| DeviceMemory<double> *a, int lda) { |
| VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), |
| PARAM(a), PARAM(lda)); |
| |
| ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &, |
| int, DeviceMemory<double> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a, |
| lda); |
| } |
| |
| Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, float alpha, |
| const DeviceMemory<float> &x, int incx, |
| const DeviceMemory<float> &y, int incy, |
| DeviceMemory<float> *a, int lda) { |
| VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), |
| PARAM(y), PARAM(incy), PARAM(a), PARAM(lda)); |
| |
| ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &, |
| int, const DeviceMemory<float> &, int, DeviceMemory<float> *, |
| int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y, |
| incy, a, lda); |
| } |
| |
| Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, double alpha, |
| const DeviceMemory<double> &x, int incx, |
| const DeviceMemory<double> &y, int incy, |
| DeviceMemory<double> *a, int lda) { |
| VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), |
| PARAM(y), PARAM(incy), PARAM(a), PARAM(lda)); |
| |
| ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &, |
| int, const DeviceMemory<double> &, int, DeviceMemory<double> *, |
| int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y, |
| incy, a, lda); |
| } |
| |
| Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64 n, uint64 k, |
| const DeviceMemory<float> &a, int lda, |
| DeviceMemory<float> *x, int incx) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k), |
| PARAM(a), PARAM(lda), PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, |
| uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *, |
| int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a, |
| lda, x, incx); |
| } |
| |
| Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64 n, uint64 k, |
| const DeviceMemory<double> &a, int lda, |
| DeviceMemory<double> *x, int incx) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k), |
| PARAM(a), PARAM(lda), PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, |
| uint64, const DeviceMemory<double> &, int, |
| DeviceMemory<double> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a, |
| lda, x, incx); |
| } |
| |
| Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64 n, uint64 k, |
| const DeviceMemory<std::complex<float>> &a, |
| int lda, DeviceMemory<std::complex<float>> *x, |
| int incx) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k), |
| PARAM(a), PARAM(lda), PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, |
| uint64, const DeviceMemory<std::complex<float>> &, int, |
| DeviceMemory<std::complex<float>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a, |
| lda, x, incx); |
| } |
| |
| Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64 n, uint64 k, |
| const DeviceMemory<std::complex<double>> &a, |
| int lda, DeviceMemory<std::complex<double>> *x, |
| int incx) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k), |
| PARAM(a), PARAM(lda), PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, |
| uint64, const DeviceMemory<std::complex<double>> &, int, |
| DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a, |
| lda, x, incx); |
| } |
| |
| Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64 n, uint64 k, |
| const DeviceMemory<float> &a, int lda, |
| DeviceMemory<float> *x, int incx) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k), |
| PARAM(a), PARAM(lda), PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, |
| uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *, |
| int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a, |
| lda, x, incx); |
| } |
| |
| Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64 n, uint64 k, |
| const DeviceMemory<double> &a, int lda, |
| DeviceMemory<double> *x, int incx) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k), |
| PARAM(a), PARAM(lda), PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, |
| uint64, const DeviceMemory<double> &, int, |
| DeviceMemory<double> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a, |
| lda, x, incx); |
| } |
| |
| Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64 n, uint64 k, |
| const DeviceMemory<std::complex<float>> &a, |
| int lda, DeviceMemory<std::complex<float>> *x, |
| int incx) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k), |
| PARAM(a), PARAM(lda), PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, |
| uint64, const DeviceMemory<std::complex<float>> &, int, |
| DeviceMemory<std::complex<float>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a, |
| lda, x, incx); |
| } |
| |
| Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64 n, uint64 k, |
| const DeviceMemory<std::complex<double>> &a, |
| int lda, DeviceMemory<std::complex<double>> *x, |
| int incx) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k), |
| PARAM(a), PARAM(lda), PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, |
| uint64, const DeviceMemory<std::complex<double>> &, int, |
| DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a, |
| lda, x, incx); |
| } |
| |
| Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64 n, |
| const DeviceMemory<float> &ap, |
| DeviceMemory<float> *x, int incx) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap), |
| PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, |
| const DeviceMemory<float> &, DeviceMemory<float> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x, |
| incx); |
| } |
| |
| Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64 n, |
| const DeviceMemory<double> &ap, |
| DeviceMemory<double> *x, int incx) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap), |
| PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, |
| const DeviceMemory<double> &, DeviceMemory<double> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x, |
| incx); |
| } |
| |
| Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64 n, |
| const DeviceMemory<std::complex<float>> &ap, |
| DeviceMemory<std::complex<float>> *x, int incx) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap), |
| PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, |
| const DeviceMemory<std::complex<float>> &, |
| DeviceMemory<std::complex<float>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x, |
| incx); |
| } |
| |
| Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64 n, |
| const DeviceMemory<std::complex<double>> &ap, |
| DeviceMemory<std::complex<double>> *x, int incx) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap), |
| PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, |
| const DeviceMemory<std::complex<double>> &, |
| DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x, |
| incx); |
| } |
| |
| Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64 n, |
| const DeviceMemory<float> &ap, |
| DeviceMemory<float> *x, int incx) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap), |
| PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, |
| const DeviceMemory<float> &, DeviceMemory<float> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x, |
| incx); |
| } |
| |
| Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64 n, |
| const DeviceMemory<double> &ap, |
| DeviceMemory<double> *x, int incx) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap), |
| PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, |
| const DeviceMemory<double> &, DeviceMemory<double> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x, |
| incx); |
| } |
| |
| Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64 n, |
| const DeviceMemory<std::complex<float>> &ap, |
| DeviceMemory<std::complex<float>> *x, int incx) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap), |
| PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, |
| const DeviceMemory<std::complex<float>> &, |
| DeviceMemory<std::complex<float>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x, |
| incx); |
| } |
| |
| Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64 n, |
| const DeviceMemory<std::complex<double>> &ap, |
| DeviceMemory<std::complex<double>> *x, int incx) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap), |
| PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, |
| const DeviceMemory<std::complex<double>> &, |
| DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x, |
| incx); |
| } |
| |
| Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64 n, |
| const DeviceMemory<float> &a, int lda, |
| DeviceMemory<float> *x, int incx) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a), |
| PARAM(lda), PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, |
| const DeviceMemory<float> &, int, DeviceMemory<float> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a, |
| lda, x, incx); |
| } |
| |
| Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64 n, |
| const DeviceMemory<double> &a, int lda, |
| DeviceMemory<double> *x, int incx) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a), |
| PARAM(lda), PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, |
| const DeviceMemory<double> &, int, DeviceMemory<double> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a, |
| lda, x, incx); |
| } |
| |
| Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64 n, |
| const DeviceMemory<std::complex<float>> &a, |
| int lda, DeviceMemory<std::complex<float>> *x, |
| int incx) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a), |
| PARAM(lda), PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, |
| const DeviceMemory<std::complex<float>> &, int, |
| DeviceMemory<std::complex<float>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a, |
| lda, x, incx); |
| } |
| |
| Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64 n, |
| const DeviceMemory<std::complex<double>> &a, |
| int lda, DeviceMemory<std::complex<double>> *x, |
| int incx) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a), |
| PARAM(lda), PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, |
| const DeviceMemory<std::complex<double>> &, int, |
| DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a, |
| lda, x, incx); |
| } |
| |
| Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64 n, |
| const DeviceMemory<float> &a, int lda, |
| DeviceMemory<float> *x, int incx) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a), |
| PARAM(lda), PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, |
| const DeviceMemory<float> &, int, DeviceMemory<float> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a, |
| lda, x, incx); |
| } |
| |
| Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64 n, |
| const DeviceMemory<double> &a, int lda, |
| DeviceMemory<double> *x, int incx) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a), |
| PARAM(lda), PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, |
| const DeviceMemory<double> &, int, DeviceMemory<double> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a, |
| lda, x, incx); |
| } |
| |
| Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64 n, |
| const DeviceMemory<std::complex<float>> &a, |
| int lda, DeviceMemory<std::complex<float>> *x, |
| int incx) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a), |
| PARAM(lda), PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, |
| const DeviceMemory<std::complex<float>> &, int, |
| DeviceMemory<std::complex<float>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a, |
| lda, x, incx); |
| } |
| |
| Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans, |
| blas::Diagonal diag, uint64 n, |
| const DeviceMemory<std::complex<double>> &a, |
| int lda, DeviceMemory<std::complex<double>> *x, |
| int incx) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a), |
| PARAM(lda), PARAM(x), PARAM(incx)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, |
| const DeviceMemory<std::complex<double>> &, int, |
| DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a, |
| lda, x, incx); |
| } |
| |
| Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, |
| uint64 m, uint64 n, uint64 k, float alpha, |
| const DeviceMemory<Eigen::half> &a, int lda, |
| const DeviceMemory<Eigen::half> &b, int ldb, |
| float beta, DeviceMemory<Eigen::half> *c, |
| int ldc) { |
| VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), |
| PARAM(beta), PARAM(c), PARAM(ldc)); |
| |
| ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float, |
| const DeviceMemory<Eigen::half> &, int, |
| const DeviceMemory<Eigen::half> &, int, float, |
| DeviceMemory<Eigen::half> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k, |
| alpha, a, lda, b, ldb, beta, c, ldc); |
| } |
| |
| Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, |
| uint64 m, uint64 n, uint64 k, float alpha, |
| const DeviceMemory<float> &a, int lda, |
| const DeviceMemory<float> &b, int ldb, float beta, |
| DeviceMemory<float> *c, int ldc) { |
| VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), |
| PARAM(beta), PARAM(c), PARAM(ldc)); |
| |
| ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float, |
| const DeviceMemory<float> &, int, const DeviceMemory<float> &, |
| int, float, DeviceMemory<float> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k, |
| alpha, a, lda, b, ldb, beta, c, ldc); |
| } |
| |
| Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, |
| uint64 m, uint64 n, uint64 k, double alpha, |
| const DeviceMemory<double> &a, int lda, |
| const DeviceMemory<double> &b, int ldb, |
| double beta, DeviceMemory<double> *c, int ldc) { |
| VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), |
| PARAM(beta), PARAM(c), PARAM(ldc)); |
| |
| ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double, |
| const DeviceMemory<double> &, int, const DeviceMemory<double> &, |
| int, double, DeviceMemory<double> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k, |
| alpha, a, lda, b, ldb, beta, c, ldc); |
| } |
| |
| Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, |
| uint64 m, uint64 n, uint64 k, |
| std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, |
| int lda, |
| const DeviceMemory<std::complex<float>> &b, |
| int ldb, std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *c, int ldc) { |
| VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), |
| PARAM(beta), PARAM(c), PARAM(ldc)); |
| |
| ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, |
| std::complex<float>, const DeviceMemory<std::complex<float>> &, |
| int, const DeviceMemory<std::complex<float>> &, int, |
| std::complex<float>, DeviceMemory<std::complex<float>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k, |
| alpha, a, lda, b, ldb, beta, c, ldc); |
| } |
| |
| Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, |
| uint64 m, uint64 n, uint64 k, |
| std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, |
| int lda, |
| const DeviceMemory<std::complex<double>> &b, |
| int ldb, std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *c, int ldc) { |
| VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), |
| PARAM(beta), PARAM(c), PARAM(ldc)); |
| |
| ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, |
| std::complex<double>, const DeviceMemory<std::complex<double>> &, |
| int, const DeviceMemory<std::complex<double>> &, int, |
| std::complex<double>, DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k, |
| alpha, a, lda, b, ldb, beta, c, ldc); |
| } |
| |
| namespace { |
| // Like ThenBlasImpl, except this expects the last argument of blas_func to be a |
| // blas::ProfileResult*. This functor doesn't put the stream into an error |
| // state if the op fails and the profile result is non-null. Instead, the |
| // error-ness is returned in the profile result itself. |
| template <typename... Args> |
| struct ThenBlasWithProfileImpl { |
| Stream &operator()(Stream *stream, |
| bool (blas::BlasSupport::*blas_func)( |
| Stream *, Args..., blas::ProfileResult *), |
| Args... args, blas::ProfileResult *profile_result) { |
| ThenBlasImpl<Args..., blas::ProfileResult *> Runner; |
| bool record_error = profile_result == nullptr; |
| return Runner.Run(stream, blas_func, record_error, args..., profile_result); |
| } |
| }; |
| } // anonymous namespace |
| |
| Stream &Stream::ThenBlasGemvWithProfiling( |
| blas::Transpose trans, uint64 m, uint64 n, float alpha, |
| const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x, |
| int incx, float beta, DeviceMemory<float> *y, int incy, |
| blas::ProfileResult *output_profile_result) { |
| VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a), |
| PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), |
| PARAM(incy)); |
| |
| ThenBlasWithProfileImpl< |
| blas::Transpose, uint64, uint64, float, const DeviceMemory<float> &, int, |
| const DeviceMemory<float> &, int, float, DeviceMemory<float> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n, |
| alpha, a, lda, x, incx, beta, y, incy, output_profile_result); |
| } |
| |
| Stream &Stream::ThenBlasGemvWithProfiling( |
| blas::Transpose trans, uint64 m, uint64 n, double alpha, |
| const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x, |
| int incx, double beta, DeviceMemory<double> *y, int incy, |
| blas::ProfileResult *output_profile_result) { |
| VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a), |
| PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), |
| PARAM(incy)); |
| |
| ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, double, |
| const DeviceMemory<double> &, int, |
| const DeviceMemory<double> &, int, double, |
| DeviceMemory<double> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n, |
| alpha, a, lda, x, incx, beta, y, incy, output_profile_result); |
| } |
| |
| Stream &Stream::ThenBlasGemvWithProfiling( |
| blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| const DeviceMemory<std::complex<float>> &x, int incx, |
| std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy, |
| blas::ProfileResult *output_profile_result) { |
| VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a), |
| PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), |
| PARAM(incy)); |
| |
| ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<float>, |
| const DeviceMemory<std::complex<float>> &, int, |
| const DeviceMemory<std::complex<float>> &, int, |
| std::complex<float>, |
| DeviceMemory<std::complex<float>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n, |
| alpha, a, lda, x, incx, beta, y, incy, output_profile_result); |
| } |
| |
| Stream &Stream::ThenBlasGemvWithProfiling( |
| blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| const DeviceMemory<std::complex<double>> &x, int incx, |
| std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy, |
| blas::ProfileResult *output_profile_result) { |
| VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a), |
| PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), |
| PARAM(incy)); |
| |
| ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<double>, |
| const DeviceMemory<std::complex<double>> &, int, |
| const DeviceMemory<std::complex<double>> &, int, |
| std::complex<double>, |
| DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n, |
| alpha, a, lda, x, incx, beta, y, incy, output_profile_result); |
| } |
| |
| Stream &Stream::ThenBlasGemmWithProfiling( |
| blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, |
| uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda, |
| const DeviceMemory<Eigen::half> &b, int ldb, float beta, |
| DeviceMemory<Eigen::half> *c, int ldc, |
| blas::ProfileResult *output_profile_result) { |
| VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), |
| PARAM(beta), PARAM(c), PARAM(ldc)); |
| |
| ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64, |
| uint64, float, const DeviceMemory<Eigen::half> &, int, |
| const DeviceMemory<Eigen::half> &, int, float, |
| DeviceMemory<Eigen::half> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb, |
| m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, |
| output_profile_result); |
| } |
| |
| Stream &Stream::ThenBlasGemmWithProfiling( |
| blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, |
| uint64 k, float alpha, const DeviceMemory<float> &a, int lda, |
| const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c, |
| int ldc, blas::ProfileResult *output_profile_result) { |
| VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), |
| PARAM(beta), PARAM(c), PARAM(ldc)); |
| |
| ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64, |
| uint64, float, const DeviceMemory<float> &, int, |
| const DeviceMemory<float> &, int, float, |
| DeviceMemory<float> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb, |
| m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, |
| output_profile_result); |
| } |
| |
| Stream &Stream::ThenBlasGemmWithProfiling( |
| blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, |
| uint64 k, double alpha, const DeviceMemory<double> &a, int lda, |
| const DeviceMemory<double> &b, int ldb, double beta, |
| DeviceMemory<double> *c, int ldc, |
| blas::ProfileResult *output_profile_result) { |
| VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), |
| PARAM(beta), PARAM(c), PARAM(ldc)); |
| |
| ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64, |
| uint64, double, const DeviceMemory<double> &, int, |
| const DeviceMemory<double> &, int, double, |
| DeviceMemory<double> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb, |
| m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, |
| output_profile_result); |
| } |
| |
| Stream &Stream::ThenBlasGemmWithProfiling( |
| blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, |
| uint64 k, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| const DeviceMemory<std::complex<float>> &b, int ldb, |
| std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, |
| blas::ProfileResult *output_profile_result) { |
| VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), |
| PARAM(beta), PARAM(c), PARAM(ldc)); |
| |
| ThenBlasWithProfileImpl< |
| blas::Transpose, blas::Transpose, uint64, uint64, uint64, |
| std::complex<float>, const DeviceMemory<std::complex<float>> &, int, |
| const DeviceMemory<std::complex<float>> &, int, std::complex<float>, |
| DeviceMemory<std::complex<float>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb, |
| m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, |
| output_profile_result); |
| } |
| |
| Stream &Stream::ThenBlasGemmWithProfiling( |
| blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, |
| uint64 k, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| const DeviceMemory<std::complex<double>> &b, int ldb, |
| std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, |
| blas::ProfileResult *output_profile_result) { |
| VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), |
| PARAM(beta), PARAM(c), PARAM(ldc)); |
| |
| ThenBlasWithProfileImpl< |
| blas::Transpose, blas::Transpose, uint64, uint64, uint64, |
| std::complex<double>, const DeviceMemory<std::complex<double>> &, int, |
| const DeviceMemory<std::complex<double>> &, int, std::complex<double>, |
| DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb, |
| m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, |
| output_profile_result); |
| } |
| |
| Stream &Stream::ThenBlasGemmWithAlgorithm( |
| blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, |
| uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha, |
| const DeviceMemory<Eigen::half> &a, int lda, |
| const DeviceMemory<Eigen::half> &b, int ldb, |
| const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c, |
| int ldc, blas::ComputationType computation_type, |
| blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { |
| VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), |
| PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type), |
| PARAM(algorithm)); |
| |
| ThenBlasWithProfileImpl< |
| blas::Transpose, blas::Transpose, uint64, uint64, uint64, |
| const HostOrDeviceScalar<Eigen::half> &, |
| const DeviceMemory<Eigen::half> &, int, const DeviceMemory<Eigen::half> &, |
| int, const HostOrDeviceScalar<Eigen::half> &, DeviceMemory<Eigen::half> *, |
| int, blas::ComputationType, blas::AlgorithmType> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb, |
| m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type, |
| algorithm, output_profile_result); |
| } |
| |
| Stream &Stream::ThenBlasGemmWithAlgorithm( |
| blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, |
| uint64 k, const HostOrDeviceScalar<int> &alpha, const DeviceMemory<int8> &a, |
| int lda, const DeviceMemory<int8> &b, int ldb, |
| const HostOrDeviceScalar<int> &beta, DeviceMemory<int> *c, int ldc, |
| blas::ComputationType computation_type, blas::AlgorithmType algorithm, |
| blas::ProfileResult *output_profile_result) { |
| VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), |
| PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type), |
| PARAM(algorithm)); |
| |
| ThenBlasWithProfileImpl< |
| blas::Transpose, blas::Transpose, uint64, uint64, uint64, |
| const HostOrDeviceScalar<int> &, const DeviceMemory<int8> &, int, |
| const DeviceMemory<int8> &, int, const HostOrDeviceScalar<int> &, |
| DeviceMemory<int> *, int, blas::ComputationType, blas::AlgorithmType> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb, |
| m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type, |
| algorithm, output_profile_result); |
| } |
| |
| Stream &Stream::ThenBlasGemmWithAlgorithm( |
| blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, |
| uint64 k, const HostOrDeviceScalar<float> &alpha, |
| const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b, |
| int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c, |
| int ldc, blas::ComputationType computation_type, |
| blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { |
| VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), |
| PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type), |
| PARAM(algorithm)); |
| |
| ThenBlasWithProfileImpl< |
| blas::Transpose, blas::Transpose, uint64, uint64, uint64, |
| const HostOrDeviceScalar<float> &, const DeviceMemory<float> &, int, |
| const DeviceMemory<float> &, int, const HostOrDeviceScalar<float> &, |
| DeviceMemory<float> *, int, blas::ComputationType, blas::AlgorithmType> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb, |
| m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type, |
| algorithm, output_profile_result); |
| } |
| |
| Stream &Stream::ThenBlasGemmWithAlgorithm( |
| blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, |
| uint64 k, const HostOrDeviceScalar<double> &alpha, |
| const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b, |
| int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c, |
| int ldc, blas::ComputationType computation_type, |
| blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { |
| VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), |
| PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type), |
| PARAM(algorithm)); |
| |
| ThenBlasWithProfileImpl< |
| blas::Transpose, blas::Transpose, uint64, uint64, uint64, |
| const HostOrDeviceScalar<double> &, const DeviceMemory<double> &, int, |
| const DeviceMemory<double> &, int, const HostOrDeviceScalar<double> &, |
| DeviceMemory<double> *, int, blas::ComputationType, blas::AlgorithmType> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb, |
| m, n, k, HostOrDeviceScalar<double>(alpha), a, lda, b, ldb, |
| HostOrDeviceScalar<double>(beta), c, ldc, computation_type, |
| algorithm, output_profile_result); |
| } |
| |
| Stream &Stream::ThenBlasGemmWithAlgorithm( |
| blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, |
| uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, |
| const DeviceMemory<std::complex<float>> &b, int ldb, |
| const HostOrDeviceScalar<std::complex<float>> &beta, |
| DeviceMemory<std::complex<float>> *c, int ldc, |
| blas::ComputationType computation_type, blas::AlgorithmType algorithm, |
| blas::ProfileResult *output_profile_result) { |
| VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), |
| PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type), |
| PARAM(algorithm)); |
| |
| ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64, |
| uint64, |
| const HostOrDeviceScalar<std::complex<float>> &, |
| const DeviceMemory<std::complex<float>> &, int, |
| const DeviceMemory<std::complex<float>> &, int, |
| const HostOrDeviceScalar<std::complex<float>> &, |
| DeviceMemory<std::complex<float>> *, int, |
| blas::ComputationType, blas::AlgorithmType> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb, |
| m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type, |
| algorithm, output_profile_result); |
| } |
| |
| Stream &Stream::ThenBlasGemmWithAlgorithm( |
| blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, |
| uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, |
| const DeviceMemory<std::complex<double>> &b, int ldb, |
| const HostOrDeviceScalar<std::complex<double>> &beta, |
| DeviceMemory<std::complex<double>> *c, int ldc, |
| blas::ComputationType computation_type, blas::AlgorithmType algorithm, |
| blas::ProfileResult *output_profile_result) { |
| VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), |
| PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type), |
| PARAM(algorithm)); |
| |
| ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64, |
| uint64, |
| const HostOrDeviceScalar<std::complex<double>> &, |
| const DeviceMemory<std::complex<double>> &, int, |
| const DeviceMemory<std::complex<double>> &, int, |
| const HostOrDeviceScalar<std::complex<double>> &, |
| DeviceMemory<std::complex<double>> *, int, |
| blas::ComputationType, blas::AlgorithmType> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb, |
| m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type, |
| algorithm, output_profile_result); |
| } |
| |
| Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m, |
| uint64 n, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, |
| int lda, |
| const DeviceMemory<std::complex<float>> &b, |
| int ldb, std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *c, int ldc) { |
| VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha), |
| PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), |
| PARAM(ldc)); |
| |
| ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, |
| std::complex<float>, const DeviceMemory<std::complex<float>> &, |
| int, const DeviceMemory<std::complex<float>> &, int, |
| std::complex<float>, DeviceMemory<std::complex<float>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a, |
| lda, b, ldb, beta, c, ldc); |
| } |
| |
| Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m, |
| uint64 n, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, |
| int lda, |
| const DeviceMemory<std::complex<double>> &b, |
| int ldb, std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *c, int ldc) { |
| VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha), |
| PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), |
| PARAM(ldc)); |
| |
| ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, |
| std::complex<double>, const DeviceMemory<std::complex<double>> &, |
| int, const DeviceMemory<std::complex<double>> &, int, |
| std::complex<double>, DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a, |
| lda, b, ldb, beta, c, ldc); |
| } |
| |
| Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, |
| uint64 n, uint64 k, float alpha, |
| const DeviceMemory<std::complex<float>> &a, |
| int lda, float beta, |
| DeviceMemory<std::complex<float>> *c, int ldc) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha), |
| PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float, |
| const DeviceMemory<std::complex<float>> &, int, float, |
| DeviceMemory<std::complex<float>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a, |
| lda, beta, c, ldc); |
| } |
| |
| Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, |
| uint64 n, uint64 k, double alpha, |
| const DeviceMemory<std::complex<double>> &a, |
| int lda, double beta, |
| DeviceMemory<std::complex<double>> *c, int ldc) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha), |
| PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double, |
| const DeviceMemory<std::complex<double>> &, int, double, |
| DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a, |
| lda, beta, c, ldc); |
| } |
| |
| Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, |
| uint64 n, uint64 k, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, |
| int lda, |
| const DeviceMemory<std::complex<float>> &b, |
| int ldb, float beta, |
| DeviceMemory<std::complex<float>> *c, int ldc) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha), |
| PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), |
| PARAM(ldc)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, |
| std::complex<float>, const DeviceMemory<std::complex<float>> &, |
| int, const DeviceMemory<std::complex<float>> &, int, float, |
| DeviceMemory<std::complex<float>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha, |
| a, lda, b, ldb, beta, c, ldc); |
| } |
| |
| Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, |
| uint64 n, uint64 k, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, |
| int lda, |
| const DeviceMemory<std::complex<double>> &b, |
| int ldb, double beta, |
| DeviceMemory<std::complex<double>> *c, int ldc) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha), |
| PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), |
| PARAM(ldc)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, |
| std::complex<double>, const DeviceMemory<std::complex<double>> &, |
| int, const DeviceMemory<std::complex<double>> &, int, double, |
| DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha, |
| a, lda, b, ldb, beta, c, ldc); |
| } |
| |
| Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m, |
| uint64 n, float alpha, |
| const DeviceMemory<float> &a, int lda, |
| const DeviceMemory<float> &b, int ldb, float beta, |
| DeviceMemory<float> *c, int ldc) { |
| VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha), |
| PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), |
| PARAM(ldc)); |
| |
| ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, float, |
| const DeviceMemory<float> &, int, const DeviceMemory<float> &, |
| int, float, DeviceMemory<float> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a, |
| lda, b, ldb, beta, c, ldc); |
| } |
| |
| Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m, |
| uint64 n, double alpha, |
| const DeviceMemory<double> &a, int lda, |
| const DeviceMemory<double> &b, int ldb, |
| double beta, DeviceMemory<double> *c, int ldc) { |
| VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha), |
| PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), |
| PARAM(ldc)); |
| |
| ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, double, |
| const DeviceMemory<double> &, int, const DeviceMemory<double> &, |
| int, double, DeviceMemory<double> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a, |
| lda, b, ldb, beta, c, ldc); |
| } |
| |
| Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m, |
| uint64 n, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, |
| int lda, |
| const DeviceMemory<std::complex<float>> &b, |
| int ldb, std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *c, int ldc) { |
| VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha), |
| PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), |
| PARAM(ldc)); |
| |
| ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, |
| std::complex<float>, const DeviceMemory<std::complex<float>> &, |
| int, const DeviceMemory<std::complex<float>> &, int, |
| std::complex<float>, DeviceMemory<std::complex<float>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a, |
| lda, b, ldb, beta, c, ldc); |
| } |
| |
| Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m, |
| uint64 n, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, |
| int lda, |
| const DeviceMemory<std::complex<double>> &b, |
| int ldb, std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *c, int ldc) { |
| VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha), |
| PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), |
| PARAM(ldc)); |
| |
| ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, |
| std::complex<double>, const DeviceMemory<std::complex<double>> &, |
| int, const DeviceMemory<std::complex<double>> &, int, |
| std::complex<double>, DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a, |
| lda, b, ldb, beta, c, ldc); |
| } |
| |
| Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, |
| uint64 n, uint64 k, float alpha, |
| const DeviceMemory<float> &a, int lda, float beta, |
| DeviceMemory<float> *c, int ldc) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha), |
| PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float, |
| const DeviceMemory<float> &, int, float, DeviceMemory<float> *, |
| int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a, |
| lda, beta, c, ldc); |
| } |
| |
| Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, |
| uint64 n, uint64 k, double alpha, |
| const DeviceMemory<double> &a, int lda, |
| double beta, DeviceMemory<double> *c, int ldc) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha), |
| PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double, |
| const DeviceMemory<double> &, int, double, |
| DeviceMemory<double> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a, |
| lda, beta, c, ldc); |
| } |
| |
| Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, |
| uint64 n, uint64 k, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, |
| int lda, std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *c, int ldc) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha), |
| PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, |
| std::complex<float>, const DeviceMemory<std::complex<float>> &, |
| int, std::complex<float>, DeviceMemory<std::complex<float>> *, |
| int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a, |
| lda, beta, c, ldc); |
| } |
| |
| Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, |
| uint64 n, uint64 k, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, |
| int lda, std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *c, int ldc) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha), |
| PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, |
| std::complex<double>, const DeviceMemory<std::complex<double>> &, |
| int, std::complex<double>, DeviceMemory<std::complex<double>> *, |
| int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a, |
| lda, beta, c, ldc); |
| } |
| |
| Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, |
| uint64 n, uint64 k, float alpha, |
| const DeviceMemory<float> &a, int lda, |
| const DeviceMemory<float> &b, int ldb, float beta, |
| DeviceMemory<float> *c, int ldc) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha), |
| PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), |
| PARAM(ldc)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float, |
| const DeviceMemory<float> &, int, const DeviceMemory<float> &, |
| int, float, DeviceMemory<float> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha, |
| a, lda, b, ldb, beta, c, ldc); |
| } |
| |
| Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, |
| uint64 n, uint64 k, double alpha, |
| const DeviceMemory<double> &a, int lda, |
| const DeviceMemory<double> &b, int ldb, |
| double beta, DeviceMemory<double> *c, int ldc) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha), |
| PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), |
| PARAM(ldc)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double, |
| const DeviceMemory<double> &, int, const DeviceMemory<double> &, |
| int, double, DeviceMemory<double> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha, |
| a, lda, b, ldb, beta, c, ldc); |
| } |
| |
| Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, |
| uint64 n, uint64 k, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, |
| int lda, |
| const DeviceMemory<std::complex<float>> &b, |
| int ldb, std::complex<float> beta, |
| DeviceMemory<std::complex<float>> *c, int ldc) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha), |
| PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), |
| PARAM(ldc)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, |
| std::complex<float>, const DeviceMemory<std::complex<float>> &, |
| int, const DeviceMemory<std::complex<float>> &, int, |
| std::complex<float>, DeviceMemory<std::complex<float>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha, |
| a, lda, b, ldb, beta, c, ldc); |
| } |
| |
| Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, |
| uint64 n, uint64 k, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, |
| int lda, |
| const DeviceMemory<std::complex<double>> &b, |
| int ldb, std::complex<double> beta, |
| DeviceMemory<std::complex<double>> *c, int ldc) { |
| VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha), |
| PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), |
| PARAM(ldc)); |
| |
| ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, |
| std::complex<double>, const DeviceMemory<std::complex<double>> &, |
| int, const DeviceMemory<std::complex<double>> &, int, |
| std::complex<double>, DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha, |
| a, lda, b, ldb, beta, c, ldc); |
| } |
| |
| Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo, |
| blas::Transpose transa, blas::Diagonal diag, |
| uint64 m, uint64 n, float alpha, |
| const DeviceMemory<float> &a, int lda, |
| DeviceMemory<float> *b, int ldb) { |
| VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), |
| PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb)); |
| |
| ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal, |
| uint64, uint64, float, const DeviceMemory<float> &, int, |
| DeviceMemory<float> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m, |
| n, alpha, a, lda, b, ldb); |
| } |
| |
| Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo, |
| blas::Transpose transa, blas::Diagonal diag, |
| uint64 m, uint64 n, double alpha, |
| const DeviceMemory<double> &a, int lda, |
| DeviceMemory<double> *b, int ldb) { |
| VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), |
| PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb)); |
| |
| ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal, |
| uint64, uint64, double, const DeviceMemory<double> &, int, |
| DeviceMemory<double> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m, |
| n, alpha, a, lda, b, ldb); |
| } |
| |
| Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo, |
| blas::Transpose transa, blas::Diagonal diag, |
| uint64 m, uint64 n, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, |
| int lda, DeviceMemory<std::complex<float>> *b, |
| int ldb) { |
| VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), |
| PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb)); |
| |
| ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal, |
| uint64, uint64, std::complex<float>, |
| const DeviceMemory<std::complex<float>> &, int, |
| DeviceMemory<std::complex<float>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m, |
| n, alpha, a, lda, b, ldb); |
| } |
| |
| Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo, |
| blas::Transpose transa, blas::Diagonal diag, |
| uint64 m, uint64 n, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, |
| int lda, DeviceMemory<std::complex<double>> *b, |
| int ldb) { |
| VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), |
| PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb)); |
| |
| ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal, |
| uint64, uint64, std::complex<double>, |
| const DeviceMemory<std::complex<double>> &, int, |
| DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m, |
| n, alpha, a, lda, b, ldb); |
| } |
| |
| Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, |
| blas::Transpose transa, blas::Diagonal diag, |
| uint64 m, uint64 n, float alpha, |
| const DeviceMemory<float> &a, int lda, |
| DeviceMemory<float> *b, int ldb) { |
| VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), |
| PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb)); |
| |
| ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal, |
| uint64, uint64, float, const DeviceMemory<float> &, int, |
| DeviceMemory<float> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m, |
| n, alpha, a, lda, b, ldb); |
| } |
| |
| Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, |
| blas::Transpose transa, blas::Diagonal diag, |
| uint64 m, uint64 n, double alpha, |
| const DeviceMemory<double> &a, int lda, |
| DeviceMemory<double> *b, int ldb) { |
| VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), |
| PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb)); |
| |
| ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal, |
| uint64, uint64, double, const DeviceMemory<double> &, int, |
| DeviceMemory<double> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m, |
| n, alpha, a, lda, b, ldb); |
| } |
| |
| Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, |
| blas::Transpose transa, blas::Diagonal diag, |
| uint64 m, uint64 n, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, |
| int lda, DeviceMemory<std::complex<float>> *b, |
| int ldb) { |
| VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), |
| PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb)); |
| |
| ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal, |
| uint64, uint64, std::complex<float>, |
| const DeviceMemory<std::complex<float>> &, int, |
| DeviceMemory<std::complex<float>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m, |
| n, alpha, a, lda, b, ldb); |
| } |
| |
| Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, |
| blas::Transpose transa, blas::Diagonal diag, |
| uint64 m, uint64 n, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, |
| int lda, DeviceMemory<std::complex<double>> *b, |
| int ldb) { |
| VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), |
| PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb)); |
| |
| ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal, |
| uint64, uint64, std::complex<double>, |
| const DeviceMemory<std::complex<double>> &, int, |
| DeviceMemory<std::complex<double>> *, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m, |
| n, alpha, a, lda, b, ldb); |
| } |
| |
| Stream &Stream::ThenBlasGemmBatched( |
| blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, |
| uint64 k, float alpha, |
| const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda, |
| const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta, |
| const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc, |
| int batch_count) { |
| return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, |
| b, ldb, beta, c, ldc, batch_count, |
| /*scratch_allocator=*/nullptr); |
| } |
| |
| Stream &Stream::ThenBlasGemmBatchedWithScratch( |
| blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, |
| uint64 k, float alpha, |
| const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda, |
| const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta, |
| const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc, |
| int batch_count, ScratchAllocator *scratch_allocator) { |
| VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), |
| PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); |
| |
| ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float, |
| const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int, |
| const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int, |
| float, const port::ArraySlice<DeviceMemory<Eigen::half> *> &, |
| int, int, ScratchAllocator *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, |
| k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, |
| scratch_allocator); |
| } |
| |
| Stream &Stream::ThenBlasGemmBatched( |
| blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, |
| uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a, |
| int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, |
| float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc, |
| int batch_count) { |
| return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, |
| b, ldb, beta, c, ldc, batch_count, |
| /*scratch_allocator=*/nullptr); |
| } |
| |
| Stream &Stream::ThenBlasGemmBatchedWithScratch( |
| blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, |
| uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a, |
| int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, |
| float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc, |
| int batch_count, ScratchAllocator *scratch_allocator) { |
| VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), |
| PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); |
| |
| ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float, |
| const port::ArraySlice<DeviceMemory<float> *> &, int, |
| const port::ArraySlice<DeviceMemory<float> *> &, int, float, |
| const port::ArraySlice<DeviceMemory<float> *> &, int, int, |
| ScratchAllocator *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, |
| k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, |
| scratch_allocator); |
| } |
| |
| Stream &Stream::ThenBlasGemmBatched( |
| blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, |
| uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a, |
| int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, |
| double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc, |
| int batch_count) { |
| return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, |
| b, ldb, beta, c, ldc, batch_count, |
| /*scratch_allocator=*/nullptr); |
| } |
| |
| Stream &Stream::ThenBlasGemmBatchedWithScratch( |
| blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, |
| uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a, |
| int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, |
| double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc, |
| int batch_count, ScratchAllocator *scratch_allocator) { |
| VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), |
| PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); |
| |
| ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double, |
| const port::ArraySlice<DeviceMemory<double> *> &, int, |
| const port::ArraySlice<DeviceMemory<double> *> &, int, double, |
| const port::ArraySlice<DeviceMemory<double> *> &, int, int, |
| ScratchAllocator *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, |
| k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, |
| scratch_allocator); |
| } |
| |
| Stream &Stream::ThenBlasGemmBatched( |
| blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, |
| uint64 k, std::complex<float> alpha, |
| const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda, |
| const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb, |
| std::complex<float> beta, |
| const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc, |
| int batch_count) { |
| return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, |
| b, ldb, beta, c, ldc, batch_count, |
| /*scratch_allocator=*/nullptr); |
| } |
| |
| Stream &Stream::ThenBlasGemmBatchedWithScratch( |
| blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, |
| uint64 k, std::complex<float> alpha, |
| const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda, |
| const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb, |
| std::complex<float> beta, |
| const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc, |
| int batch_count, ScratchAllocator *scratch_allocator) { |
| VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), |
| PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); |
| |
| ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, |
| std::complex<float>, |
| const port::ArraySlice<DeviceMemory<std::complex<float>> *> &, |
| int, |
| const port::ArraySlice<DeviceMemory<std::complex<float>> *> &, |
| int, std::complex<float>, |
| const port::ArraySlice<DeviceMemory<std::complex<float>> *> &, |
| int, int, ScratchAllocator *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, |
| k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, |
| scratch_allocator); |
| } |
| |
| Stream &Stream::ThenBlasGemmBatched( |
| blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, |
| uint64 k, std::complex<double> alpha, |
| const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda, |
| const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb, |
| std::complex<double> beta, |
| const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc, |
| int batch_count) { |
| return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, |
| b, ldb, beta, c, ldc, batch_count, |
| /*scratch_allocator=*/nullptr); |
| } |
| |
| Stream &Stream::ThenBlasGemmBatchedWithScratch( |
| blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, |
| uint64 k, std::complex<double> alpha, |
| const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda, |
| const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb, |
| std::complex<double> beta, |
| const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc, |
| int batch_count, ScratchAllocator *scratch_allocator) { |
| VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), |
| PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); |
| |
| ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, |
| std::complex<double>, |
| const port::ArraySlice<DeviceMemory<std::complex<double>> *> &, |
| int, |
| const port::ArraySlice<DeviceMemory<std::complex<double>> *> &, |
| int, std::complex<double>, |
| const port::ArraySlice<DeviceMemory<std::complex<double>> *> &, |
| int, int, ScratchAllocator *> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, |
| k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, |
| scratch_allocator); |
| } |
| |
| Stream &Stream::ThenBlasGemmStridedBatched( |
| blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, |
| uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda, |
| int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b, |
| float beta, DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c, |
| int batch_count) { |
| VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), |
| PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), |
| PARAM(stride_c), PARAM(batch_count)); |
| |
| ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float, |
| const DeviceMemory<Eigen::half> &, int, int64, |
| const DeviceMemory<Eigen::half> &, int, int64, float, |
| DeviceMemory<Eigen::half> *, int, int64, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, |
| transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, |
| c, ldc, stride_c, batch_count); |
| } |
| |
| Stream &Stream::ThenBlasGemmStridedBatched( |
| blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, |
| uint64 k, float alpha, const DeviceMemory<float> &a, int lda, |
| int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b, |
| float beta, DeviceMemory<float> *c, int ldc, int64 stride_c, |
| int batch_count) { |
| VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), |
| PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), |
| PARAM(stride_c), PARAM(batch_count)); |
| |
| ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float, |
| const DeviceMemory<float> &, int, int64, |
| const DeviceMemory<float> &, int, int64, float, |
| DeviceMemory<float> *, int, int64, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, |
| transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, |
| c, ldc, stride_c, batch_count); |
| } |
| |
| Stream &Stream::ThenBlasGemmStridedBatched( |
| blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, |
| uint64 k, double alpha, const DeviceMemory<double> &a, int lda, |
| int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b, |
| double beta, DeviceMemory<double> *c, int ldc, int64 stride_c, |
| int batch_count) { |
| VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), |
| PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), |
| PARAM(stride_c), PARAM(batch_count)); |
| |
| ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double, |
| const DeviceMemory<double> &, int, int64, |
| const DeviceMemory<double> &, int, int64, double, |
| DeviceMemory<double> *, int, int64, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, |
| transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, |
| c, ldc, stride_c, batch_count); |
| } |
| |
| Stream &Stream::ThenBlasGemmStridedBatched( |
| blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, |
| uint64 k, std::complex<float> alpha, |
| const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a, |
| const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b, |
| std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, |
| int64 stride_c, int batch_count) { |
| VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), |
| PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), |
| PARAM(stride_c), PARAM(batch_count)); |
| |
| ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, |
| std::complex<float>, const DeviceMemory<std::complex<float>> &, |
| int, int64, const DeviceMemory<std::complex<float>> &, int, |
| int64, std::complex<float>, DeviceMemory<std::complex<float>> *, |
| int, int64, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, |
| transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, |
| c, ldc, stride_c, batch_count); |
| } |
| |
| Stream &Stream::ThenBlasGemmStridedBatched( |
| blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, |
| uint64 k, std::complex<double> alpha, |
| const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a, |
| const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b, |
| std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, |
| int64 stride_c, int batch_count) { |
| VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), |
| PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), |
| PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), |
| PARAM(stride_c), PARAM(batch_count)); |
| |
| ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, |
| std::complex<double>, const DeviceMemory<std::complex<double>> &, |
| int, int64, const DeviceMemory<std::complex<double>> &, int, |
| int64, std::complex<double>, |
| DeviceMemory<std::complex<double>> *, int, int64, int> |
| impl; |
| return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, |
| transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, |
| c, ldc, stride_c, batch_count); |
| } |
| |
// Shared implementation for all ThenBlasLtMatmul type combinations: logs the
// call, then dispatches DoBlasLtMatmul (with optional profiling) to the
// platform BLAS plugin. ABType is the element type of the A/B operands and
// CType that of the C operand and the alpha/beta scalars. Defined here in the
// .cc file and explicitly instantiated below for each supported pairing.
template <typename ABType, typename CType>
Stream &Stream::ThenBlasLtMatmulImpl(
    const blas::IBlasLtMatmulPlan *plan, const HostOrDeviceScalar<CType> &alpha,
    const DeviceMemory<ABType> &a, const DeviceMemory<ABType> &b,
    const HostOrDeviceScalar<CType> &beta, DeviceMemory<CType> *c,
    ScratchAllocator *scratch_allocator,
    const blas::IBlasLtMatmulAlgorithm *algorithm,
    const DeviceMemory<CType> &bias,
    blas::ProfileResult *output_profile_result) {
  // Note: scratch_allocator and output_profile_result are intentionally not
  // VLOG'ed here.
  VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
            PARAM(c), PARAM(algorithm), PARAM(bias));

  ThenBlasWithProfileImpl<
      const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<CType> &,
      const DeviceMemory<ABType> &, const DeviceMemory<ABType> &,
      const HostOrDeviceScalar<CType> &, DeviceMemory<CType> *,
      ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
      const DeviceMemory<CType> &>
      impl;
  return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
              c, scratch_allocator, algorithm, bias, output_profile_result);
}
| |
// Explicit template instantiations for each supported type combination.
// The template definition lives in this translation unit, so each
// <ABType, CType> pairing used by the public ThenBlasLtMatmul overloads must
// be instantiated here explicitly.
template Stream &Stream::ThenBlasLtMatmulImpl<int8, int32>(
    const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<int32> &,
    const DeviceMemory<int8> &, const DeviceMemory<int8> &,
    const HostOrDeviceScalar<int32> &, DeviceMemory<int32> *,
    ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
    const DeviceMemory<int32> &, blas::ProfileResult *);

template Stream &Stream::ThenBlasLtMatmulImpl<Eigen::half, Eigen::half>(
    const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<Eigen::half> &,
    const DeviceMemory<Eigen::half> &, const DeviceMemory<Eigen::half> &,
    const HostOrDeviceScalar<Eigen::half> &, DeviceMemory<Eigen::half> *,
    ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
    const DeviceMemory<Eigen::half> &, blas::ProfileResult *);

template Stream &Stream::ThenBlasLtMatmulImpl<float, float>(
    const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<float> &,
    const DeviceMemory<float> &, const DeviceMemory<float> &,
    const HostOrDeviceScalar<float> &, DeviceMemory<float> *,
    ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
    const DeviceMemory<float> &, blas::ProfileResult *);

template Stream &Stream::ThenBlasLtMatmulImpl<double, double>(
    const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<double> &,
    const DeviceMemory<double> &, const DeviceMemory<double> &,
    const HostOrDeviceScalar<double> &, DeviceMemory<double> *,
    ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
    const DeviceMemory<double> &, blas::ProfileResult *);

template Stream &
Stream::ThenBlasLtMatmulImpl<std::complex<float>, std::complex<float>>(
    const blas::IBlasLtMatmulPlan *,
    const HostOrDeviceScalar<std::complex<float>> &,
    const DeviceMemory<std::complex<float>> &,
    const DeviceMemory<std::complex<float>> &,
    const HostOrDeviceScalar<std::complex<float>> &,
    DeviceMemory<std::complex<float>> *, ScratchAllocator *,
    const blas::IBlasLtMatmulAlgorithm *,
    const DeviceMemory<std::complex<float>> &, blas::ProfileResult *);

template Stream &
Stream::ThenBlasLtMatmulImpl<std::complex<double>, std::complex<double>>(
    const blas::IBlasLtMatmulPlan *,
    const HostOrDeviceScalar<std::complex<double>> &,
    const DeviceMemory<std::complex<double>> &,
    const DeviceMemory<std::complex<double>> &,
    const HostOrDeviceScalar<std::complex<double>> &,
    DeviceMemory<std::complex<double>> *, ScratchAllocator *,
    const blas::IBlasLtMatmulAlgorithm *,
    const DeviceMemory<std::complex<double>> &, blas::ProfileResult *);
| |
| Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) { |
| VLOG_CALL(PARAM(seed), PARAM(seed_bytes)); |
| |
| if (rng::RngSupport *rng = parent_->AsRng()) { |
| CheckError(rng->SetSeed(this, seed, seed_bytes)); |
| } else { |
| SetError(); |
| LOG(INFO) << DebugStreamPointers() << " unable to initialize RNG"; |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenPopulateRandUniform(DeviceMemory<float> *values) { |
| VLOG_CALL(PARAM(values)); |
| |
| if (rng::RngSupport *rng = parent_->AsRng()) { |
| CheckError(rng->DoPopulateRandUniform(this, values)); |
| } else { |
| SetError(); |
| LOG(INFO) << DebugStreamPointers() |
| << " attempting to perform RNG operation using StreamExecutor" |
| " without RNG support."; |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenPopulateRandGaussian(float mean, float sd, |
| DeviceMemory<float> *values) { |
| VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values)); |
| |
| if (rng::RngSupport *rng = parent_->AsRng()) { |
| CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values)); |
| } else { |
| SetError(); |
| LOG(INFO) << DebugStreamPointers() |
| << " attempting to perform RNG operation using StreamExecutor" |
| " without RNG support."; |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenPopulateRandGaussian(double mean, double sd, |
| DeviceMemory<double> *values) { |
| VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values)); |
| |
| if (rng::RngSupport *rng = parent_->AsRng()) { |
| CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values)); |
| } else { |
| SetError(); |
| LOG(INFO) << DebugStreamPointers() |
| << " attempting to perform RNG operation using StreamExecutor" |
| " without RNG support."; |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenPopulateRandUniform(DeviceMemory<double> *values) { |
| VLOG_CALL(PARAM(values)); |
| |
| if (rng::RngSupport *rng = parent_->AsRng()) { |
| CheckError(rng->DoPopulateRandUniform(this, values)); |
| } else { |
| SetError(); |
| LOG(INFO) << DebugStreamPointers() |
| << " attempting to perform RNG operation using StreamExecutor" |
| " without RNG support."; |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenPopulateRandUniform( |
| DeviceMemory<std::complex<float>> *values) { |
| VLOG_CALL(PARAM(values)); |
| |
| if (rng::RngSupport *rng = parent_->AsRng()) { |
| CheckError(rng->DoPopulateRandUniform(this, values)); |
| } else { |
| SetError(); |
| LOG(INFO) << DebugStreamPointers() |
| << " attempting to perform RNG operation using StreamExecutor" |
| " without RNG support."; |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenPopulateRandUniform( |
| DeviceMemory<std::complex<double>> *values) { |
| VLOG_CALL(PARAM(values)); |
| |
| if (rng::RngSupport *rng = parent_->AsRng()) { |
| CheckError(rng->DoPopulateRandUniform(this, values)); |
| } else { |
| SetError(); |
| LOG(INFO) << DebugStreamPointers() |
| << " attempting to perform RNG operation using StreamExecutor" |
| " without RNG support."; |
| } |
| return *this; |
| } |
| |
// Enqueues a device-to-host copy of `size` bytes from `gpu_src` to
// `host_dst` on this stream; a failed enqueue is latched into the stream's
// error status via CheckError.
Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
                           uint64 size) {
  VLOG_CALL(PARAM(host_dst), PARAM(gpu_src), PARAM(size));

  CheckError(parent_->Memcpy(this, host_dst, gpu_src, size));
  return *this;
}
| |
// Enqueues a host-to-device copy of `size` bytes from `host_src` to
// `gpu_dst` on this stream; a failed enqueue is latched into the stream's
// error status via CheckError.
Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
                           uint64 size) {
  VLOG_CALL(PARAM(gpu_dst), PARAM(host_src), PARAM(size));

  CheckError(parent_->Memcpy(this, gpu_dst, host_src, size));
  return *this;
}
| |
// Enqueues a device-to-device copy of `size` bytes from `gpu_src` to
// `gpu_dst` on this stream; a failed enqueue is latched into the stream's
// error status via CheckError.
Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst,
                           const DeviceMemoryBase &gpu_src, uint64 size) {
  VLOG_CALL(PARAM(gpu_dst), PARAM(gpu_src), PARAM(size));

  CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size));
  return *this;
}
| |
// Enqueues zeroing of `size` bytes at `location` on this stream; a failed
// enqueue is latched into the stream's error status via CheckStatus.
Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64 size) {
  VLOG_CALL(PARAM(location), PARAM(size));

  CheckStatus(parent_->MemZero(this, location, size));
  return *this;
}
| |
// Enqueues filling `size` bytes at `location` with the 32-bit `pattern` on
// this stream; a failed enqueue is latched into the stream's error status
// via CheckStatus.
Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern,
                             uint64 size) {
  VLOG_CALL(PARAM(location), PARAM(pattern), PARAM(size));

  CheckStatus(parent_->Memset32(this, location, pattern, size));
  return *this;
}
| |
// Enqueues the forward pass of the RNN described by `rnn_desc` on fp16 data,
// dispatching to the platform DNN plugin. The long positional argument list
// must stay in exactly this order — it mirrors DnnSupport::DoRnnForward.
Stream &Stream::ThenRnnForward(
    const dnn::RnnDescriptor &rnn_desc,
    const dnn::RnnSequenceTensorDescriptor &input_desc,
    const DeviceMemory<Eigen::half> &input_data,
    const dnn::RnnStateTensorDescriptor &input_h_desc,
    const DeviceMemory<Eigen::half> &input_h_data,
    const dnn::RnnStateTensorDescriptor &input_c_desc,
    const DeviceMemory<Eigen::half> &input_c_data,
    const DeviceMemory<Eigen::half> &params,
    const dnn::RnnSequenceTensorDescriptor &output_desc,
    DeviceMemory<Eigen::half> *output_data,
    const dnn::RnnStateTensorDescriptor &output_h_desc,
    DeviceMemory<Eigen::half> *output_h_data,
    const dnn::RnnStateTensorDescriptor &output_c_desc,
    DeviceMemory<Eigen::half> *output_c_data, bool is_training,
    ScratchAllocator *reserve_space_allocator,
    ScratchAllocator *workspace_allocator,
    dnn::ProfileResult *output_profile_result) {
  // TODO(zhengxq): add VLOG PARAM calls.
  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    auto status = dnn->DoRnnForward(
        this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
        input_c_desc, input_c_data, params, output_desc, output_data,
        output_h_desc, output_h_data, output_c_desc, output_c_data, is_training,
        reserve_space_allocator, workspace_allocator, output_profile_result);
    // A failure only latches the stream error when the caller is not
    // profiling; presumably profiled runs report failure through
    // `output_profile_result` instead — confirm against DnnSupport contract.
    if (!status && !output_profile_result) {
      SetError();
    }
  } else {
    SetErrorAndLogNoDnnSupport();
  }
  return *this;
}
| |
// Enqueues the forward pass of the RNN described by `rnn_desc` on fp32 data,
// dispatching to the platform DNN plugin. The long positional argument list
// must stay in exactly this order — it mirrors DnnSupport::DoRnnForward.
Stream &Stream::ThenRnnForward(
    const dnn::RnnDescriptor &rnn_desc,
    const dnn::RnnSequenceTensorDescriptor &input_desc,
    const DeviceMemory<float> &input_data,
    const dnn::RnnStateTensorDescriptor &input_h_desc,
    const DeviceMemory<float> &input_h_data,
    const dnn::RnnStateTensorDescriptor &input_c_desc,
    const DeviceMemory<float> &input_c_data, const DeviceMemory<float> &params,
    const dnn::RnnSequenceTensorDescriptor &output_desc,
    DeviceMemory<float> *output_data,
    const dnn::RnnStateTensorDescriptor &output_h_desc,
    DeviceMemory<float> *output_h_data,
    const dnn::RnnStateTensorDescriptor &output_c_desc,
    DeviceMemory<float> *output_c_data, bool is_training,
    ScratchAllocator *reserve_space_allocator,
    ScratchAllocator *workspace_allocator,
    dnn::ProfileResult *output_profile_result) {
  // TODO(zhengxq): add VLOG PARAM calls.
  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    auto status = dnn->DoRnnForward(
        this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
        input_c_desc, input_c_data, params, output_desc, output_data,
        output_h_desc, output_h_data, output_c_desc, output_c_data, is_training,
        reserve_space_allocator, workspace_allocator, output_profile_result);
    // Failure only latches the stream error when not profiling (see the
    // fp16 overload above for the same convention).
    if (!status && !output_profile_result) {
      SetError();
    }
  } else {
    SetErrorAndLogNoDnnSupport();
  }
  return *this;
}
| |
// Enqueues the forward pass of the RNN described by `rnn_desc` on fp64 data,
// dispatching to the platform DNN plugin. The long positional argument list
// must stay in exactly this order — it mirrors DnnSupport::DoRnnForward.
Stream &Stream::ThenRnnForward(
    const dnn::RnnDescriptor &rnn_desc,
    const dnn::RnnSequenceTensorDescriptor &input_desc,
    const DeviceMemory<double> &input_data,
    const dnn::RnnStateTensorDescriptor &input_h_desc,
    const DeviceMemory<double> &input_h_data,
    const dnn::RnnStateTensorDescriptor &input_c_desc,
    const DeviceMemory<double> &input_c_data,
    const DeviceMemory<double> &params,
    const dnn::RnnSequenceTensorDescriptor &output_desc,
    DeviceMemory<double> *output_data,
    const dnn::RnnStateTensorDescriptor &output_h_desc,
    DeviceMemory<double> *output_h_data,
    const dnn::RnnStateTensorDescriptor &output_c_desc,
    DeviceMemory<double> *output_c_data, bool is_training,
    ScratchAllocator *reserve_space_allocator,
    ScratchAllocator *workspace_allocator,
    dnn::ProfileResult *output_profile_result) {
  // TODO(zhengxq): add VLOG PARAM calls.
  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    auto status = dnn->DoRnnForward(
        this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
        input_c_desc, input_c_data, params, output_desc, output_data,
        output_h_desc, output_h_data, output_c_desc, output_c_data, is_training,
        reserve_space_allocator, workspace_allocator, output_profile_result);
    // Failure only latches the stream error when not profiling (see the
    // fp16 overload above for the same convention).
    if (!status && !output_profile_result) {
      SetError();
    }
  } else {
    SetErrorAndLogNoDnnSupport();
  }
  return *this;
}
| |
| Stream &Stream::ThenRnnBackward( |
| const dnn::RnnDescriptor &rnn_desc, |
| const dnn::RnnSequenceTensorDescriptor &input_desc, |
| const DeviceMemory<Eigen::half> &input_data, |
| const dnn::RnnStateTensorDescriptor &input_h_desc, |
| const DeviceMemory<Eigen::half> &input_h_data, |
| const dnn::RnnStateTensorDescriptor &input_c_desc, |
| const DeviceMemory<Eigen::half> &input_c_data, |
| const DeviceMemory<Eigen::half> ¶ms, |
| const dnn::RnnSequenceTensorDescriptor &output_desc, |
| const DeviceMemory<Eigen::half> &output_data, |
| const dnn::RnnStateTensorDescriptor &output_h_desc, |
| const DeviceMemory<Eigen::half> &output_h_data, |
| const dnn::RnnStateTensorDescriptor &output_c_desc, |
| const DeviceMemory<Eigen::half> &output_c_data, |
| const DeviceMemory<Eigen::half> &output_backprop_data, |
| const DeviceMemory<Eigen::half> &output_h_backprop_data, |
| const DeviceMemory<Eigen::half> &output_c_backprop_data, |
| DeviceMemory<Eigen::half> *input_backprop_data, |
| DeviceMemory<Eigen::half> *input_h_backprop_data, |
| DeviceMemory<Eigen::half> *input_c_backprop_data, |
| DeviceMemory<Eigen::half> *params_backprop_data, |
| DeviceMemory<uint8> *reserve_space_data, |
| ScratchAllocator *workspace_allocator, |
| dnn::ProfileResult *output_profile_result) { |
| // TODO(zhengxq): add VLOG PARAM calls. |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| auto status = dnn->DoRnnBackward( |
| this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, |
| input_c_desc, input_c_data, params, output_desc, output_data, |
| output_h_desc, output_h_data, output_c_desc, output_c_data, |
| output_backprop_data, output_h_backprop_data, output_c_backprop_data, |
| input_backprop_data, input_h_backprop_data, input_c_backprop_data, |
| params_backprop_data, reserve_space_data, workspace_allocator, |
| output_profile_result); |
| if (!status && !output_profile_result) { |
| SetError(); |
| } |
| } else { |
| SetError(); |
| LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support"; |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenRnnBackward( |
| const dnn::RnnDescriptor &rnn_desc, |
| const dnn::RnnSequenceTensorDescriptor &input_desc, |
| const DeviceMemory<float> &input_data, |
| const dnn::RnnStateTensorDescriptor &input_h_desc, |
| const DeviceMemory<float> &input_h_data, |
| const dnn::RnnStateTensorDescriptor &input_c_desc, |
| const DeviceMemory<float> &input_c_data, const DeviceMemory<float> ¶ms, |
| const dnn::RnnSequenceTensorDescriptor &output_desc, |
| const DeviceMemory<float> &output_data, |
| const dnn::RnnStateTensorDescriptor &output_h_desc, |
| const DeviceMemory<float> &output_h_data, |
| const dnn::RnnStateTensorDescriptor &output_c_desc, |
| const DeviceMemory<float> &output_c_data, |
| const DeviceMemory<float> &output_backprop_data, |
| const DeviceMemory<float> &output_h_backprop_data, |
| const DeviceMemory<float> &output_c_backprop_data, |
| DeviceMemory<float> *input_backprop_data, |
| DeviceMemory<float> *input_h_backprop_data, |
| DeviceMemory<float> *input_c_backprop_data, |
| DeviceMemory<float> *params_backprop_data, |
| DeviceMemory<uint8> *reserve_space_data, |
| ScratchAllocator *workspace_allocator, |
| dnn::ProfileResult *output_profile_result) { |
| // TODO(zhengxq): add VLOG PARAM calls. |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| auto status = dnn->DoRnnBackward( |
| this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, |
| input_c_desc, input_c_data, params, output_desc, output_data, |
| output_h_desc, output_h_data, output_c_desc, output_c_data, |
| output_backprop_data, output_h_backprop_data, output_c_backprop_data, |
| input_backprop_data, input_h_backprop_data, input_c_backprop_data, |
| params_backprop_data, reserve_space_data, workspace_allocator, |
| output_profile_result); |
| if (!status && !output_profile_result) { |
| SetError(); |
| } |
| } else { |
| SetError(); |
| LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support"; |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenRnnBackward( |
| const dnn::RnnDescriptor &rnn_desc, |
| const dnn::RnnSequenceTensorDescriptor &input_desc, |
| const DeviceMemory<double> &input_data, |
| const dnn::RnnStateTensorDescriptor &input_h_desc, |
| const DeviceMemory<double> &input_h_data, |
| const dnn::RnnStateTensorDescriptor &input_c_desc, |
| const DeviceMemory<double> &input_c_data, |
| const DeviceMemory<double> ¶ms, |
| const dnn::RnnSequenceTensorDescriptor &output_desc, |
| const DeviceMemory<double> &output_data, |
| const dnn::RnnStateTensorDescriptor &output_h_desc, |
| const DeviceMemory<double> &output_h_data, |
| const dnn::RnnStateTensorDescriptor &output_c_desc, |
| const DeviceMemory<double> &output_c_data, |
| const DeviceMemory<double> &output_backprop_data, |
| const DeviceMemory<double> &output_h_backprop_data, |
| const DeviceMemory<double> &output_c_backprop_data, |
| DeviceMemory<double> *input_backprop_data, |
| DeviceMemory<double> *input_h_backprop_data, |
| DeviceMemory<double> *input_c_backprop_data, |
| DeviceMemory<double> *params_backprop_data, |
| DeviceMemory<uint8> *reserve_space_data, |
| ScratchAllocator *workspace_allocator, |
| dnn::ProfileResult *output_profile_result) { |
| // TODO(zhengxq): add VLOG PARAM calls. |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| auto status = dnn->DoRnnBackward( |
| this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, |
| input_c_desc, input_c_data, params, output_desc, output_data, |
| output_h_desc, output_h_data, output_c_desc, output_c_data, |
| output_backprop_data, output_h_backprop_data, output_c_backprop_data, |
| input_backprop_data, input_h_backprop_data, input_c_backprop_data, |
| params_backprop_data, reserve_space_data, workspace_allocator, |
| output_profile_result); |
| if (!status && !output_profile_result) { |
| SetError(); |
| } |
| } else { |
| SetError(); |
| LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support"; |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc, |
| const DeviceMemory<float> &probs_data, |
| absl::Span<const int> labels_data, |
| absl::Span<const int> labels_lengths_data, |
| absl::Span<const int> input_lengths_data, |
| DeviceMemory<float> *costs_data, |
| const dnn::RnnStateTensorDescriptor &grads_desc, |
| DeviceMemory<float> *grads_data, |
| ScratchAllocator *workspace_allocator) { |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| DeviceMemory<uint8> scratch_memory; |
| int ctc_loss_algo_id; |
| auto status = |
| dnn->PrepareForCtcLoss(this, probs_desc, probs_data, grads_desc, |
| labels_data, labels_lengths_data, |
| input_lengths_data, workspace_allocator, |
| &scratch_memory, &ctc_loss_algo_id) |
| .ok(); |
| if (status) { |
| status = dnn->DoCtcLoss(this, probs_desc, probs_data, labels_data, |
| labels_lengths_data, input_lengths_data, |
| costs_data, grads_desc, grads_data, |
| &scratch_memory, ctc_loss_algo_id); |
| } |
| if (!status) { |
| SetError(); |
| } |
| } else { |
| SetErrorAndLogNoDnnSupport(); |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc, |
| dnn::DataType input_type, |
| const DeviceMemoryBase &input_data, |
| const dnn::BatchDescriptor &output_desc, |
| dnn::DataType output_type, float scale, |
| DeviceMemoryBase *output_data) { |
| VLOG_CALL(PARAM(input_desc), PARAM(input_type), PARAM(input_data), |
| PARAM(output_desc), PARAM(output_type), PARAM(scale), |
| PARAM(output_data)); |
| if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
| CheckError(dnn->DoTransformTensor(this, input_desc, input_type, input_data, |
| output_desc, output_type, scale, |
| output_data)); |
| } else { |
| SetErrorAndLogNoDnnSupport(); |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenDoHostCallback(std::function<void()> callback) { |
| VLOG_CALL(PARAM(callback)); |
| |
| if (!ok()) { |
| LOG(INFO) << DebugStreamPointers() |
| << " was in error state before adding host callback"; |
| } |
| CheckError(parent_->HostCallback(this, std::move(callback))); |
| return *this; |
| } |
| |
| Stream &Stream::ThenDoHostCallbackWithStatus( |
| std::function<port::Status()> callback) { |
| VLOG_CALL(PARAM(callback)); |
| |
| if (!ok()) { |
| LOG(INFO) << DebugStreamPointers() |
| << " was in error state before adding host callback"; |
| } |
| CheckError(parent_->HostCallback(this, std::move(callback))); |
| return *this; |
| } |
| |
| Stream &Stream::ThenRunAfterNextBlockHostUntilDone( |
| std::function<void()> callback) { |
| VLOG_CALL(PARAM(callback)); |
| |
| if (!ok()) { |
| LOG(INFO) << DebugStreamPointers() |
| << " was in error state before adding callback to be run after " |
| "next block-host-until-done."; |
| } |
| absl::MutexLock lock(&mu_); |
| after_block_host_until_done_callbacks_.push_back(std::move(callback)); |
| return *this; |
| } |
| |
| Stream &Stream::ThenFft(fft::Plan *plan, |
| const DeviceMemory<std::complex<float>> &input, |
| DeviceMemory<std::complex<float>> *output) { |
| VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output)); |
| |
| if (fft::FftSupport *fft = parent_->AsFft()) { |
| CheckError(fft->DoFft(this, plan, input, output)); |
| } else { |
| SetError(); |
| LOG(INFO) << DebugStreamPointers() |
| << " attempting to perform FFT operation using StreamExecutor" |
| " without FFT support"; |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenFft(fft::Plan *plan, |
| const DeviceMemory<std::complex<double>> &input, |
| DeviceMemory<std::complex<double>> *output) { |
| VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output)); |
| |
| if (fft::FftSupport *fft = parent_->AsFft()) { |
| CheckError(fft->DoFft(this, plan, input, output)); |
| } else { |
| SetError(); |
| LOG(INFO) << DebugStreamPointers() |
| << " attempting to perform FFT operation using StreamExecutor" |
| " without FFT support"; |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<float> &input, |
| DeviceMemory<std::complex<float>> *output) { |
| VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output)); |
| |
| if (fft::FftSupport *fft = parent_->AsFft()) { |
| CheckError(fft->DoFft(this, plan, input, output)); |
| } else { |
| SetError(); |
| LOG(INFO) << DebugStreamPointers() |
| << " attempting to perform FFT operation using StreamExecutor" |
| " without FFT support"; |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<double> &input, |
| DeviceMemory<std::complex<double>> *output) { |
| VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output)); |
| |
| if (fft::FftSupport *fft = parent_->AsFft()) { |
| CheckError(fft->DoFft(this, plan, input, output)); |
| } else { |
| SetError(); |
| LOG(INFO) << DebugStreamPointers() |
| << " attempting to perform FFT operation using StreamExecutor" |
| " without FFT support"; |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenFft(fft::Plan *plan, |
| const DeviceMemory<std::complex<float>> &input, |
| DeviceMemory<float> *output) { |
| VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output)); |
| |
| if (fft::FftSupport *fft = parent_->AsFft()) { |
| CheckError(fft->DoFft(this, plan, input, output)); |
| } else { |
| SetError(); |
| LOG(INFO) << DebugStreamPointers() |
| << " attempting to perform FFT operation using StreamExecutor" |
| " without FFT support"; |
| } |
| return *this; |
| } |
| |
| Stream &Stream::ThenFft(fft::Plan *plan, |
| const DeviceMemory<std::complex<double>> &input, |
| DeviceMemory<double> *output) { |
| VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output)); |
| |
| if (fft::FftSupport *fft = parent_->AsFft()) { |
| CheckError(fft->DoFft(this, plan, input, output)); |
| } else { |
| SetError(); |
| LOG(INFO) << DebugStreamPointers() |
| << " attempting to perform FFT operation using StreamExecutor" |
| " without FFT support"; |
| } |
| return *this; |
| } |
| |
| // It looks confusing, but all this is doing is inserting a callback at the |
| // present point in the stream to then enqueue a task on the host executor. |
| Stream &Stream::ThenEnqueueOnBackgroundThread( |
| std::function<void(StreamExecutor *)> task) { |
| VLOG_CALL(PARAM(task)); |
| |
| StreamExecutor *stream_executor = this->parent_; |
| std::function<void()> bound_task = std::bind(task, stream_executor); |
| |
| return ThenDoHostCallback([stream_executor, bound_task]() { |
| stream_executor->EnqueueOnBackgroundThread(bound_task); |
| }); |
| } |
| |
| port::Status Stream::BlockHostUntilDone() { |
| VLOG_CALL(); |
| |
| if (!ok()) { |
| port::Status status = port::Status( |
| port::error::INTERNAL, |
| "stream did not block host until done; was already in an error state"); |
| LOG(INFO) << DebugStreamPointers() << " " << status; |
| return status; |
| } |
| |
| temporary_memory_manager_.DeallocateFinalizedTemporaries(); |
| |
| port::Status error = parent_->BlockHostUntilDone(this); |
| CheckError(error.ok()); |
| |
| RunAfterBlockHostUntilDoneCallbacks(); |
| return error; |
| } |
| |
| void Stream::RunAfterBlockHostUntilDoneCallbacks() { |
| std::vector<std::function<void()>> callbacks; |
| { |
| absl::MutexLock lock(&mu_); |
| std::swap(callbacks, after_block_host_until_done_callbacks_); |
| } |
| for (const auto &fn : callbacks) { |
| fn(); |
| } |
| } |
| |
| std::string Stream::DebugStreamPointers() const { |
| // Relies on the ToVlogString(const void*) overload above. |
| return absl::StrCat("[stream=", ToVlogString(this), |
| ",impl=", ToVlogString(implementation_.get()), "]"); |
| } |
| |
| void Stream::CheckStatus(port::Status status) { |
| if (status.ok()) { |
| return; |
| } |
| LOG(ERROR) << status; |
| absl::MutexLock lock(&mu_); |
| status_ = status; |
| } |
| |
| } // namespace stream_executor |