blob: f658ff7420a002ddf9e123373e352b09d691a644 [file] [log] [blame]
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "absl/strings/str_cat.h"
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/host_or_device_scalar.h"
#include "tensorflow/stream_executor/lib/stacktrace.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
namespace stream_executor {
namespace {
// Code to turn parameters to functions on stream into strings that
// will be VLOG'ed. We need overloads, instead of
// e.g. BatchDescriptorToVlogString(), as the code that calls these
// functions does not know what the type of the parameter is.
// Stringifiers for DNN descriptor and enum parameters; each delegates to
// the type's own short-form printer so VLOG lines stay compact.
string ToVlogString(const dnn::BatchDescriptor &desc) {
  return desc.ToShortString();
}

string ToVlogString(const dnn::FilterDescriptor &desc) {
  return desc.ToShortString();
}

string ToVlogString(const dnn::ConvolutionDescriptor &desc) {
  return desc.ToShortString();
}

string ToVlogString(const dnn::PoolingDescriptor &desc) {
  return desc.ToShortString();
}

string ToVlogString(const dnn::NormalizeDescriptor &desc) {
  return desc.ToShortString();
}

string ToVlogString(dnn::ActivationMode mode) {
  return dnn::ActivationModeString(mode);
}

string ToVlogString(const dnn::AlgorithmConfig &algo_config) {
  return algo_config.ToString();
}

string ToVlogString(dnn::ElementwiseOperation op) {
  return dnn::ElementwiseOperationString(op);
}

string ToVlogString(dnn::QuantizedActivationMode mode) {
  return dnn::QuantizedActivationModeString(mode);
}
// Stringifiers for BLAS enum parameters, delegating to the blas helpers.
string ToVlogString(blas::Transpose t) {
  return blas::TransposeString(t);
}

string ToVlogString(blas::UpperLower ul) {
  return blas::UpperLowerString(ul);
}

string ToVlogString(blas::Diagonal d) {
  return blas::DiagonalString(d);
}

string ToVlogString(blas::Side s) {
  return blas::SideString(s);
}

string ToVlogString(blas::ComputationType ty) {
  return blas::ComputationTypeString(ty);
}
// Formats a raw pointer value ("null" for nullptr).  StrCat has no
// pointer overload, so route through an ostringstream.
string ToVlogString(const void *ptr) {
  if (ptr == nullptr) return "null";
  std::ostringstream formatted;
  formatted << ptr;
  return formatted.str();
}

// Formats a std::complex value; StrCat cannot convert std::complex, so
// use the stream inserter instead.
template <class T>
string ToVlogString(const std::complex<T> &c) {
  std::ostringstream formatted;
  formatted << c;
  return formatted.str();
}
// Summarizes a std::function; the callable target itself is not printable.
template <class T>
string ToVlogString(const std::function<T> &f) {
  if (f == nullptr) return "null";
  return "<non-null function>";
}

// Device memory is identified by its opaque device address.
string ToVlogString(const DeviceMemoryBase &memory) {
  return ToVlogString(memory.opaque());
}

string ToVlogString(const DeviceMemoryBase *memory) {
  if (memory == nullptr) return "null";
  return ToVlogString(*memory);
}
// Numeric scalars are rendered with StrCat; half goes through float since
// StrCat has no Eigen::half overload.
string ToVlogString(const Eigen::half &h) {
  const float as_float = static_cast<float>(h);
  return absl::StrCat(as_float);
}

string ToVlogString(int i) {
  return absl::StrCat(i);
}

string ToVlogString(uint32 i) {
  return absl::StrCat(i);
}

string ToVlogString(uint64 i) {
  return absl::StrCat(i);
}

string ToVlogString(int64 i) {
  return absl::StrCat(i);
}

string ToVlogString(float f) {
  return absl::StrCat(f);
}

string ToVlogString(double d) {
  return absl::StrCat(d);
}
// A HostOrDeviceScalar is printed either as its device address or as its
// host-side value, depending on which it holds.
template <typename T>
string ToVlogString(const HostOrDeviceScalar<T> &memory_or_constant) {
  return memory_or_constant.is_pointer()
             ? ToVlogString(memory_or_constant.pointer())
             : ToVlogString(memory_or_constant.value());
}
// Formats an array slice as "<ptr>[<size>]{e0, e1, ...}".  The number of
// elements printed grows with the active VLOG verbosity; past the cap the
// list is elided with ", ...".
template <class T>
string ToVlogString(port::ArraySlice<T> elements) {
  string str = absl::StrCat(
      ToVlogString(reinterpret_cast<const void *>(elements.data())), "[",
      elements.size(), "]{");
  // Pick the element cap from the current verbosity level.
  size_t max_to_show = std::numeric_limits<size_t>::max();
  if (!VLOG_IS_ON(2)) {
    max_to_show = 5;
  } else if (!VLOG_IS_ON(3)) {
    max_to_show = 20;
  } else if (!VLOG_IS_ON(11)) {
    max_to_show = 1000;
  }
  const size_t limit =
      elements.size() < max_to_show ? elements.size() : max_to_show;
  for (size_t i = 0; i < limit; ++i) {
    if (i != 0) {
      absl::StrAppend(&str, ", ");
    }
    absl::StrAppend(&str, ToVlogString(elements[i]));
  }
  if (limit < elements.size()) {
    absl::StrAppend(&str, ", ...");
  }
  absl::StrAppend(&str, "}");
  return str;
}
// Mutable slices reuse the read-only formatting path.
template <class T>
string ToVlogString(port::MutableArraySlice<T> elements) {
  port::ArraySlice<T> read_only(elements);
  return ToVlogString(read_only);
}
// Spells out the single known DepthToSpaceLayout value; anything else is
// reported as unknown.
string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) {
  if (depth_to_space_layout == dnn::DepthToSpaceLayout::DepthHeightWidth) {
    return "DepthToSpaceLayout::DepthHeightWidth";
  }
  return "unknown DepthToSpaceLayout";
}
// Maps a dnn::DataType to its fully qualified enumerator spelling for logs.
string ToVlogString(dnn::DataType data_type) {
  if (data_type == dnn::DataType::kFloat) return "dnn::DataType::kFloat";
  if (data_type == dnn::DataType::kDouble) return "dnn::DataType::kDouble";
  if (data_type == dnn::DataType::kHalf) return "dnn::DataType::kHalf";
  if (data_type == dnn::DataType::kInt8) return "dnn::DataType::kInt8";
  if (data_type == dnn::DataType::kInt32) return "dnn::DataType::kInt32";
  return "unknown DataType";
}
// Used together with PARAM to VLOG calls made to the stream. Intended
// to be used like this:
//
// VLOG(1) << CallStr("MyFunction", this, {PARAM(a), PARAM(b)});
//
// where a and b are the parameters to MyFunction.
//
// See VLOG_CALL for a short-hand for this. This way of doing it saves
// a tremendous amount of boilerplate code given how many functions
// there are on Stream and how many parameters they each have.
// Builds the VLOG line for a call to a Stream method: the stream's debug
// pointers, "Called Stream::<function_name>(", then "name=value" for each
// parameter.  Appends a stack trace at VLOG level 10+.
//
// |params| is taken by const reference: the original passed the vector by
// value, which copied every stringified parameter on each logged call.
string CallStr(const char *function_name, Stream *stream,
               const std::vector<std::pair<const char *, string>> &params) {
  // Do not call this function unless VLOG is on since just
  // constructing all the strings in params is expensive.
  CHECK(VLOG_IS_ON(1));
  string str = absl::StrCat(stream->DebugStreamPointers(),
                            " Called Stream::", function_name, "(");
  const char *separator = "";
  for (const auto &param : params) {
    absl::StrAppend(&str, separator, param.first, "=", param.second);
    separator = ", ";
  }
  absl::StrAppend(&str, ")");
  if (VLOG_IS_ON(10)) {
    absl::StrAppend(&str, " ", port::CurrentStackTrace(), "\n");
  }
  return str;
}
// Use this macro to avoid having to type every parameter twice to log
// it with VLOG and CallStr.
#define PARAM(parameter) \
{ #parameter, ToVlogString(parameter) }
// Use this macro to avoid having to type out the name of each
// function and to save some boilerplate. Intended to be used like this:
//
// VLOG_CALL(PARAM(a), PARAM(b))
//
// This saves a tremendous amount of boilerplate compared to the alternative:
//
// VLOG(1) << "Calling MyFunction(a=" << ToVlogString(a)
// << ", b=" << ToVlogString(b);
//
// Note here that most of the parameter names are not short and that
// most of the functions take many more than 2 parameters.
#define VLOG_CALL(...) VLOG(1) << CallStr(__func__, this, {__VA_ARGS__})
} // namespace
// Constructs a stream bound to |parent|, asking the parent's platform
// implementation for a stream implementation.  The stream starts
// unallocated and not-ok; Init() must be called before use.
Stream::Stream(StreamExecutor *parent)
    : parent_(parent),
      implementation_(parent->implementation()->GetStreamImplementation()),
      allocated_(false),
      ok_(false),
      temporary_memory_manager_(this) {
  VLOG_CALL(PARAM(parent));
}
// Constructs a stream bound to |parent| with a caller-supplied platform
// implementation (the stream takes it as-is rather than asking the parent
// for one).  Init() must still be called before use.
Stream::Stream(StreamExecutor *parent,
               internal::StreamInterface *implementation)
    : parent_(parent),
      implementation_(implementation),
      allocated_(false),
      ok_(false),
      temporary_memory_manager_(this) {
  VLOG_CALL(PARAM(parent), PARAM(implementation));
}
// Destructor: drains all pending work on the stream, then tears down the
// stream's resources.  The order below matters: temporaries are released
// and deferred callbacks run only after the device work has completed.
Stream::~Stream() {
  VLOG_CALL();
  // Ensure the stream is completed.
  auto status = BlockHostUntilDone();
  if (!status.ok()) {
    LOG(WARNING) << "Error blocking host until done in stream destructor: "
                 << status;
  }
  temporary_memory_manager_.ForceDeallocateAll();
  RunAfterBlockHostUntilDoneCallbacks();
  // Only give the stream back to the executor if it was ever allocated.
  if (allocated_) {
    parent_->DeallocateStream(this);
  }
}
// Queries the platform for this stream's current status, records it on the
// stream, and returns it to the caller.
port::Status Stream::RefreshStatus() {
  port::Status current_status = parent_->GetStatus(this);
  CheckStatus(current_status);
  return current_status;
}
// Allocates the underlying platform stream and marks this stream usable.
// Must be called exactly once; double-initialization is a fatal error.
Stream &Stream::Init() {
  VLOG_CALL();
  // Serialize initialization against concurrent state changes.
  absl::MutexLock lock(&mu_);
  CHECK_EQ(false, allocated_)
      << "stream appears to already have been initialized";
  CHECK(!ok_) << "stream should be in !ok() state pre-initialization";
  if (parent_->AllocateStream(this)) {
    // Successful initialization!
    allocated_ = true;
    ok_ = true;
  } else {
    // Leave the stream !ok(); subsequent Then* calls become no-ops.
    LOG(ERROR) << "failed to allocate stream during initialization";
  }
  return *this;
}
// Allocates |timer| against this stream's executor.  Skipped (with an
// informational log) if the stream is already in an error state.
Stream &Stream::InitTimer(Timer *timer) {
  VLOG_CALL(PARAM(timer));
  if (!ok()) {
    LOG(INFO) << "did not allocate timer: " << timer;
    return *this;
  }
  CheckError(parent_->AllocateTimer(timer));
  return *this;
}
// Convenience: initializes the stream, then allocates |timer| on it.
Stream &Stream::InitWithTimer(Timer *timer) {
  VLOG_CALL(PARAM(timer));
  Stream &initialized = Init();
  return initialized.InitTimer(timer);
}
// Enqueues recording of |event| on this stream.  A failure is logged but
// does NOT poison the stream, since the Event object itself may be at fault.
Stream &Stream::ThenRecordEvent(Event *event) {
  VLOG_CALL(PARAM(event));
  const port::Status record_status = parent_->RecordEvent(this, event);
  if (!record_status.ok()) {
    LOG(ERROR) << "Error recording event in stream: "
               << record_status.error_message()
               << "; not marking stream as bad, as the Event object may be "
               << "at fault. Monitor for further errors.";
  }
  return *this;
}
// Enqueues a float batch-normalization forward pass on the DNN backend.
// No-op if the stream is already in an error state; sets an error if the
// executor has no DNN support.
Stream &Stream::ThenBatchNormalizationForward(
    const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
    const DeviceMemory<float> &offset,
    const DeviceMemory<float> &estimated_mean,
    const DeviceMemory<float> &estimated_variance,
    const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
    const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    const double exponential_average_factor,
    dnn::ActivationMode activation_mode, DeviceMemory<float> *y,
    DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
    DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
    bool is_training,
    std::function<const DeviceMemory<float> &()> var_to_inv_var,
    std::function<void()> inv_var_to_var,
    ScratchAllocator *reserve_space_allocator,
    ScratchAllocator *workspace_allocator) {
  VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
            PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(dnn->DoBatchNormalizationForward(
      this, x, scale, offset, estimated_mean, estimated_variance, side_input,
      x_desc, scale_offset_desc, epsilon, exponential_average_factor,
      activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
      is_training, reserve_space_allocator, workspace_allocator,
      std::move(var_to_inv_var), std::move(inv_var_to_var)));
  return *this;
}
// Enqueues a float batch-normalization backward pass on the DNN backend.
// No-op if the stream is already in an error state; sets an error if the
// executor has no DNN support.
Stream &Stream::ThenBatchNormalizationBackward(
    const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
    const DeviceMemory<float> &scale, const DeviceMemory<float> &mean,
    const DeviceMemory<float> &inv_var, const dnn::BatchDescriptor &x_desc,
    const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop,
    DeviceMemory<float> *offset_backprop,
    DeviceMemory<uint8> *reserve_space_data,
    ScratchAllocator *workspace_allocator) {
  VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
            PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
            PARAM(scale_backprop), PARAM(offset_backprop));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(dnn->DoBatchNormalizationBackward(
      this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc,
      epsilon, x_backprop, scale_backprop, offset_backprop, reserve_space_data,
      workspace_allocator));
  return *this;
}
// Half-precision variant of the batch-normalization forward pass (scale,
// offset, and statistics remain float).  Same error-handling contract as
// the float overload.
Stream &Stream::ThenBatchNormalizationForward(
    const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
    const DeviceMemory<float> &offset,
    const DeviceMemory<float> &estimated_mean,
    const DeviceMemory<float> &estimated_variance,
    const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
    const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    const double exponential_average_factor,
    dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half> *y,
    DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
    DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
    bool is_training,
    std::function<const DeviceMemory<float> &()> var_to_inv_var,
    std::function<void()> inv_var_to_var,
    ScratchAllocator *reserve_space_allocator,
    ScratchAllocator *workspace_allocator) {
  VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
            PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(dnn->DoBatchNormalizationForward(
      this, x, scale, offset, estimated_mean, estimated_variance, side_input,
      x_desc, scale_offset_desc, epsilon, exponential_average_factor,
      activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
      is_training, reserve_space_allocator, workspace_allocator,
      std::move(var_to_inv_var), std::move(inv_var_to_var)));
  return *this;
}
// Half-precision variant of the batch-normalization backward pass.  Same
// error-handling contract as the float overload.
Stream &Stream::ThenBatchNormalizationBackward(
    const DeviceMemory<Eigen::half> &y_backprop,
    const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
    const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
    const dnn::BatchDescriptor &x_desc,
    const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    DeviceMemory<Eigen::half> *x_backprop, DeviceMemory<float> *scale_backprop,
    DeviceMemory<float> *offset_backprop,
    DeviceMemory<uint8> *reserve_space_data,
    ScratchAllocator *workspace_allocator) {
  VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
            PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
            PARAM(scale_backprop), PARAM(offset_backprop));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(dnn->DoBatchNormalizationBackward(
      this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc,
      epsilon, x_backprop, scale_backprop, offset_backprop, reserve_space_data,
      workspace_allocator));
  return *this;
}
// Enqueues a double-precision fused convolution (conv + side input + bias +
// activation).  When profiling is requested (|output_profile_result| set),
// a failure is reported through the profile result rather than by marking
// the stream as in error.
Stream &Stream::ThenFusedConvolveWithAlgorithm(
    const dnn::BatchDescriptor &conv_input_descriptor,
    const DeviceMemory<double> &conv_input_data, double conv_input_scale,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<double> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const DeviceMemory<double> &side_input_data, double side_input_scale,
    const dnn::BatchDescriptor &bias_descriptor,
    const DeviceMemory<double> &biases, dnn::ActivationMode activation_mode,
    const dnn::BatchDescriptor &output_descriptor, DeviceMemory<double> *output,
    ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result) {
  VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
            PARAM(conv_input_scale), PARAM(filter_descriptor),
            PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
            PARAM(side_input_data), PARAM(side_input_scale),
            PARAM(activation_mode), PARAM(output_descriptor), PARAM(output),
            PARAM(algorithm_config));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  const bool launched = dnn->DoFusedConvolve(
      this, conv_input_descriptor, conv_input_data, conv_input_scale,
      filter_descriptor, filter_data, convolution_descriptor, side_input_data,
      side_input_scale, bias_descriptor, biases, activation_mode,
      output_descriptor, output, scratch_allocator, algorithm_config,
      output_profile_result);
  if (!launched && !output_profile_result) {
    SetError();
  }
  return *this;
}
// Single-precision fused convolution (conv + side input + bias +
// activation).  When profiling is requested, a failure is reported through
// the profile result rather than by marking the stream as in error.
Stream &Stream::ThenFusedConvolveWithAlgorithm(
    const dnn::BatchDescriptor &conv_input_descriptor,
    const DeviceMemory<float> &conv_input_data, float conv_input_scale,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<float> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const DeviceMemory<float> &side_input_data, float side_input_scale,
    const dnn::BatchDescriptor &bias_descriptor,
    const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
    const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
    ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result) {
  VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
            PARAM(conv_input_scale), PARAM(filter_descriptor),
            PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
            PARAM(side_input_data), PARAM(side_input_scale),
            PARAM(activation_mode), PARAM(output_descriptor), PARAM(output),
            PARAM(algorithm_config));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  const bool launched = dnn->DoFusedConvolve(
      this, conv_input_descriptor, conv_input_data, conv_input_scale,
      filter_descriptor, filter_data, convolution_descriptor, side_input_data,
      side_input_scale, bias_descriptor, biases, activation_mode,
      output_descriptor, output, scratch_allocator, algorithm_config,
      output_profile_result);
  if (!launched && !output_profile_result) {
    SetError();
  }
  return *this;
}
// Half-precision fused convolution (conv + side input + bias + activation).
// When profiling is requested (|output_profile_result| set), a failure is
// reported through the profile result rather than by marking the stream as
// in error.
//
// Fix: the VLOG_CALL previously listed PARAM(biases) twice (once after
// convolution_descriptor and again after bias_descriptor); the stray
// duplicate is removed.
Stream &Stream::ThenFusedConvolveWithAlgorithm(
    const dnn::BatchDescriptor &conv_input_descriptor,
    const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<Eigen::half> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
    const dnn::BatchDescriptor &bias_descriptor,
    const DeviceMemory<Eigen::half> &biases,
    dnn::ActivationMode activation_mode,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result) {
  VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
            PARAM(conv_input_scale), PARAM(filter_descriptor),
            PARAM(filter_data), PARAM(convolution_descriptor),
            PARAM(side_input_data), PARAM(side_input_scale),
            PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
            PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));
  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      auto status = dnn->DoFusedConvolve(
          this, conv_input_descriptor, conv_input_data, conv_input_scale,
          filter_descriptor, filter_data, convolution_descriptor,
          side_input_data, side_input_scale, bias_descriptor, biases,
          activation_mode, output_descriptor, output, scratch_allocator,
          algorithm_config, output_profile_result);
      if (!status && !output_profile_result) {
        SetError();
      }
    } else {
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}
// int8 fused convolution with int8 output (float bias).  When profiling is
// requested, a failure is reported through the profile result rather than
// by marking the stream as in error.
//
// Fix: the VLOG_CALL previously listed PARAM(biases) twice; the stray
// duplicate after PARAM(convolution_descriptor) is removed.
Stream &Stream::ThenFusedConvolveWithAlgorithm(
    const dnn::BatchDescriptor &conv_input_descriptor,
    const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<int8> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const DeviceMemory<int8> &side_input_data, float side_input_scale,
    const dnn::BatchDescriptor &bias_descriptor,
    const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
    const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
    ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result) {
  VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
            PARAM(conv_input_scale), PARAM(filter_descriptor),
            PARAM(filter_data), PARAM(convolution_descriptor),
            PARAM(side_input_data), PARAM(side_input_scale),
            PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
            PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));
  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      auto status = dnn->DoFusedConvolve(
          this, conv_input_descriptor, conv_input_data, conv_input_scale,
          filter_descriptor, filter_data, convolution_descriptor,
          side_input_data, side_input_scale, bias_descriptor, biases,
          activation_mode, output_descriptor, output, scratch_allocator,
          algorithm_config, output_profile_result);
      if (!status && !output_profile_result) {
        SetError();
      }
    } else {
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}
// int8 fused convolution with float side input/output.  When profiling is
// requested, a failure is reported through the profile result rather than
// by marking the stream as in error.
//
// Fix: the VLOG_CALL previously listed PARAM(biases) twice; the stray
// duplicate after PARAM(convolution_descriptor) is removed.
Stream &Stream::ThenFusedConvolveWithAlgorithm(
    const dnn::BatchDescriptor &conv_input_descriptor,
    const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<int8> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const DeviceMemory<float> &side_input_data, float side_input_scale,
    const dnn::BatchDescriptor &bias_descriptor,
    const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
    const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
    ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result) {
  VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
            PARAM(conv_input_scale), PARAM(filter_descriptor),
            PARAM(filter_data), PARAM(convolution_descriptor),
            PARAM(side_input_data), PARAM(side_input_scale),
            PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
            PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));
  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      auto status = dnn->DoFusedConvolve(
          this, conv_input_descriptor, conv_input_data, conv_input_scale,
          filter_descriptor, filter_data, convolution_descriptor,
          side_input_data, side_input_scale, bias_descriptor, biases,
          activation_mode, output_descriptor, output, scratch_allocator,
          algorithm_config, output_profile_result);
      if (!status && !output_profile_result) {
        SetError();
      }
    } else {
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}
// Double-precision forward convolution: prepares the algorithm and scratch
// space, then launches the convolution.  When profiling is requested, a
// failure is reported through the profile result rather than by marking
// the stream as in error.
Stream &Stream::ThenConvolveWithAlgorithm(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<double> &input_data,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<double> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor, DeviceMemory<double> *output,
    ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result) {
  VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
            PARAM(filter_descriptor), PARAM(filter_data),
            PARAM(convolution_descriptor), PARAM(output_descriptor),
            PARAM(output), PARAM(algorithm_config));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  DeviceMemory<uint8> scratch_memory;
  dnn::AlgorithmDesc algorithm_desc;
  bool succeeded =
      dnn->PrepareForConvolution(
             dnn::ConvolutionKind::FORWARD, this, input_descriptor, input_data,
             filter_descriptor, filter_data, output_descriptor, *output,
             convolution_descriptor, algorithm_config, scratch_allocator,
             &algorithm_desc, &scratch_memory)
          .ok();
  if (succeeded) {
    succeeded = dnn->DoConvolve(
        this, input_descriptor, input_data, filter_descriptor, filter_data,
        convolution_descriptor, output_descriptor, output, algorithm_desc,
        &scratch_memory, output_profile_result);
  }
  if (!succeeded && !output_profile_result) {
    SetError();
  }
  return *this;
}
// Single-precision forward convolution: prepares the algorithm and scratch
// space, then launches the convolution.  When profiling is requested, a
// failure is reported through the profile result rather than by marking
// the stream as in error.
Stream &Stream::ThenConvolveWithAlgorithm(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<float> &input_data,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<float> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
    ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result) {
  VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
            PARAM(filter_descriptor), PARAM(filter_data),
            PARAM(convolution_descriptor), PARAM(output_descriptor),
            PARAM(output), PARAM(algorithm_config));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  DeviceMemory<uint8> scratch_memory;
  dnn::AlgorithmDesc algorithm_desc;
  bool succeeded =
      dnn->PrepareForConvolution(
             dnn::ConvolutionKind::FORWARD, this, input_descriptor, input_data,
             filter_descriptor, filter_data, output_descriptor, *output,
             convolution_descriptor, algorithm_config, scratch_allocator,
             &algorithm_desc, &scratch_memory)
          .ok();
  if (succeeded) {
    succeeded = dnn->DoConvolve(
        this, input_descriptor, input_data, filter_descriptor, filter_data,
        convolution_descriptor, output_descriptor, output, algorithm_desc,
        &scratch_memory, output_profile_result);
  }
  if (!succeeded && !output_profile_result) {
    SetError();
  }
  return *this;
}
// Half-precision forward convolution: prepares the algorithm and scratch
// space, then launches the convolution.  When profiling is requested, a
// failure is reported through the profile result rather than by marking
// the stream as in error.
Stream &Stream::ThenConvolveWithAlgorithm(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<Eigen::half> &input_data,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<Eigen::half> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result) {
  VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
            PARAM(filter_descriptor), PARAM(filter_data),
            PARAM(convolution_descriptor), PARAM(output_descriptor),
            PARAM(output), PARAM(algorithm_config));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  DeviceMemory<uint8> scratch_memory;
  dnn::AlgorithmDesc algorithm_desc;
  bool succeeded =
      dnn->PrepareForConvolution(
             dnn::ConvolutionKind::FORWARD, this, input_descriptor, input_data,
             filter_descriptor, filter_data, output_descriptor, *output,
             convolution_descriptor, algorithm_config, scratch_allocator,
             &algorithm_desc, &scratch_memory)
          .ok();
  if (succeeded) {
    succeeded = dnn->DoConvolve(
        this, input_descriptor, input_data, filter_descriptor, filter_data,
        convolution_descriptor, output_descriptor, output, algorithm_desc,
        &scratch_memory, output_profile_result);
  }
  if (!succeeded && !output_profile_result) {
    SetError();
  }
  return *this;
}
// int8-input, float-output forward convolution: prepares the algorithm and
// scratch space, then launches the convolution.  When profiling is
// requested, a failure is reported through the profile result rather than
// by marking the stream as in error.
Stream &Stream::ThenConvolveWithAlgorithm(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<int8> &input_data,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<int8> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
    ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result) {
  VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
            PARAM(filter_descriptor), PARAM(filter_data),
            PARAM(convolution_descriptor), PARAM(output_descriptor),
            PARAM(output), PARAM(algorithm_config));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  DeviceMemory<uint8> scratch_memory;
  dnn::AlgorithmDesc algorithm_desc;
  bool succeeded =
      dnn->PrepareForConvolution(
             dnn::ConvolutionKind::FORWARD, this, input_descriptor, input_data,
             filter_descriptor, filter_data, output_descriptor, *output,
             convolution_descriptor, algorithm_config, scratch_allocator,
             &algorithm_desc, &scratch_memory)
          .ok();
  if (succeeded) {
    succeeded = dnn->DoConvolve(
        this, input_descriptor, input_data, filter_descriptor, filter_data,
        convolution_descriptor, output_descriptor, output, algorithm_desc,
        &scratch_memory, output_profile_result);
  }
  if (!succeeded && !output_profile_result) {
    SetError();
  }
  return *this;
}
// int8-input, int8-output forward convolution: prepares the algorithm and
// scratch space, then launches the convolution.  When profiling is
// requested, a failure is reported through the profile result rather than
// by marking the stream as in error.
Stream &Stream::ThenConvolveWithAlgorithm(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<int8> &input_data,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<int8> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
    ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result) {
  VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
            PARAM(filter_descriptor), PARAM(filter_data),
            PARAM(convolution_descriptor), PARAM(output_descriptor),
            PARAM(output), PARAM(algorithm_config));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  DeviceMemory<uint8> scratch_memory;
  dnn::AlgorithmDesc algorithm_desc;
  bool succeeded =
      dnn->PrepareForConvolution(
             dnn::ConvolutionKind::FORWARD, this, input_descriptor, input_data,
             filter_descriptor, filter_data, output_descriptor, *output,
             convolution_descriptor, algorithm_config, scratch_allocator,
             &algorithm_desc, &scratch_memory)
          .ok();
  if (succeeded) {
    succeeded = dnn->DoConvolve(
        this, input_descriptor, input_data, filter_descriptor, filter_data,
        convolution_descriptor, output_descriptor, output, algorithm_desc,
        &scratch_memory, output_profile_result);
  }
  if (!succeeded && !output_profile_result) {
    SetError();
  }
  return *this;
}
// Convenience wrapper: forwards to the algorithm-selecting overload with a
// default AlgorithmConfig, no scratch allocator, and no profiling.
Stream &Stream::ThenConvolve(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<float> &input_data,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<float> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<float> *output) {
  ScratchAllocator *no_scratch_allocator = nullptr;
  dnn::ProfileResult *no_profile_result = nullptr;
  return ThenConvolveWithAlgorithm(
      input_descriptor, input_data, filter_descriptor, filter_data,
      convolution_descriptor, output_descriptor, output, no_scratch_allocator,
      dnn::AlgorithmConfig(), no_profile_result);
}
// Enqueues a quantized convolution with int8 filter coefficients and
// per-coefficient float scales.
//
// Fix (consistency): the no-DNN branch previously open-coded
// SetError() + LOG(WARNING); use the shared SetErrorAndLogNoDnnSupport()
// helper that every other Stream method in this file uses, so the message
// and behavior stay in one place.
Stream &Stream::ThenConvolveQuantized(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<float> &input_data,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<int8> &filter_coefficients,
    const DeviceMemory<float> &coefficient_scales,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<float> *output) {
  VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
            PARAM(filter_descriptor), PARAM(filter_coefficients),
            PARAM(coefficient_scales), PARAM(convolution_descriptor),
            PARAM(output_descriptor), PARAM(output));
  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      CheckError(dnn->DoConvolveQuantized(
          this, input_descriptor, input_data, filter_descriptor,
          filter_coefficients, coefficient_scales, convolution_descriptor,
          output_descriptor, output));
    } else {
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}
// Enqueues a quantized convolution with int16 filter coefficients and
// per-coefficient float scales.
//
// Fix (consistency): the no-DNN branch previously open-coded
// SetError() + LOG(WARNING); use the shared SetErrorAndLogNoDnnSupport()
// helper that every other Stream method in this file uses.
Stream &Stream::ThenConvolveQuantized(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<float> &input_data,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<int16> &filter_coefficients,
    const DeviceMemory<float> &coefficient_scales,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<float> *output) {
  VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
            PARAM(filter_descriptor), PARAM(filter_coefficients),
            PARAM(coefficient_scales), PARAM(convolution_descriptor),
            PARAM(output_descriptor), PARAM(output));
  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      CheckError(dnn->DoConvolveQuantized(
          this, input_descriptor, input_data, filter_descriptor,
          filter_coefficients, coefficient_scales, convolution_descriptor,
          output_descriptor, output));
    } else {
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}
// Enqueues a separable convolution (depthwise first-stage weights followed
// by pointwise second-stage weights).  No-op if the stream is in an error
// state; sets an error if the executor has no DNN support.
Stream &Stream::ThenSeparableConvolve(
    const dnn::BatchDescriptor &batch_descriptor,
    const DeviceMemory<float> &input_data,
    const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
    const DeviceMemory<float> &first_weights,
    const DeviceMemory<float> &second_weights,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<float> *output) {
  VLOG_CALL(
      PARAM(batch_descriptor), PARAM(input_data), PARAM(filter_descriptor),
      PARAM(depth_multiplier), PARAM(first_weights), PARAM(second_weights),
      PARAM(convolution_descriptor), PARAM(output_descriptor), PARAM(output));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(dnn->DoSeparableConvolve(
      this, batch_descriptor, input_data, filter_descriptor, depth_multiplier,
      first_weights, second_weights, convolution_descriptor, output_descriptor,
      output));
  return *this;
}
// Enqueues a double-precision backward-data convolution using an explicit
// algorithm configuration. A failure only poisons the stream when the caller
// is not profiling (output_profile_result == nullptr), since failed
// algorithms are expected during autotuning.
Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<double> &filter_data,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<double> backward_output_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &input_descriptor,
    DeviceMemory<double> *backward_input_data,
    ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result) {
  // Also log algorithm_config, for consistency with the Eigen::half overload
  // of this method (it was missing from this overload's VLOG_CALL).
  VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
            PARAM(output_descriptor), PARAM(backward_output_data),
            PARAM(convolution_descriptor), PARAM(input_descriptor),
            PARAM(backward_input_data), PARAM(algorithm_config));
  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      DeviceMemory<uint8> scratch_memory;
      dnn::AlgorithmDesc algorithm_desc;
      // Resolve the algorithm and allocate any scratch memory it requires.
      auto status =
          dnn->PrepareForConvolution(
                 dnn::ConvolutionKind::BACKWARD_DATA, this, input_descriptor,
                 *backward_input_data, filter_descriptor, filter_data,
                 output_descriptor, backward_output_data,
                 convolution_descriptor, algorithm_config, scratch_allocator,
                 &algorithm_desc, &scratch_memory)
              .ok();
      if (status) {
        status = dnn->DoConvolveBackwardData(
            this, filter_descriptor, filter_data, output_descriptor,
            backward_output_data, convolution_descriptor, input_descriptor,
            backward_input_data, algorithm_desc, &scratch_memory,
            output_profile_result);
      }
      if (!status && !output_profile_result) {
        SetError();
      }
    } else {
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}
// Enqueues a single-precision backward-data convolution using an explicit
// algorithm configuration. A failure only poisons the stream when the caller
// is not profiling (output_profile_result == nullptr), since failed
// algorithms are expected during autotuning.
Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<float> &filter_data,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<float> backward_output_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &input_descriptor,
    DeviceMemory<float> *backward_input_data,
    ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result) {
  // Also log algorithm_config, for consistency with the Eigen::half overload
  // of this method (it was missing from this overload's VLOG_CALL).
  VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
            PARAM(output_descriptor), PARAM(backward_output_data),
            PARAM(convolution_descriptor), PARAM(input_descriptor),
            PARAM(backward_input_data), PARAM(algorithm_config));
  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      DeviceMemory<uint8> scratch_memory;
      dnn::AlgorithmDesc algorithm_desc;
      // Resolve the algorithm and allocate any scratch memory it requires.
      auto status =
          dnn->PrepareForConvolution(
                 dnn::ConvolutionKind::BACKWARD_DATA, this, input_descriptor,
                 *backward_input_data, filter_descriptor, filter_data,
                 output_descriptor, backward_output_data,
                 convolution_descriptor, algorithm_config, scratch_allocator,
                 &algorithm_desc, &scratch_memory)
              .ok();
      if (status) {
        status = dnn->DoConvolveBackwardData(
            this, filter_descriptor, filter_data, output_descriptor,
            backward_output_data, convolution_descriptor, input_descriptor,
            backward_input_data, algorithm_desc, &scratch_memory,
            output_profile_result);
      }
      if (!status && !output_profile_result) {
        SetError();
      }
    } else {
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}
// Enqueues a half-precision backward-data convolution using an explicit
// algorithm configuration. A failure only poisons the stream when the caller
// is not profiling (output_profile_result == nullptr).
Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<Eigen::half> &filter_data,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<Eigen::half> backward_output_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &input_descriptor,
    DeviceMemory<Eigen::half> *backward_input_data,
    ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result) {
  VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
            PARAM(output_descriptor), PARAM(backward_output_data),
            PARAM(convolution_descriptor), PARAM(input_descriptor),
            PARAM(backward_input_data), PARAM(algorithm_config));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  // First pick the algorithm and allocate scratch space; only then run the
  // actual backward-data convolution.
  DeviceMemory<uint8> scratch_memory;
  dnn::AlgorithmDesc algorithm_desc;
  bool succeeded =
      dnn->PrepareForConvolution(
             dnn::ConvolutionKind::BACKWARD_DATA, this, input_descriptor,
             *backward_input_data, filter_descriptor, filter_data,
             output_descriptor, backward_output_data, convolution_descriptor,
             algorithm_config, scratch_allocator, &algorithm_desc,
             &scratch_memory)
          .ok();
  if (succeeded) {
    succeeded = dnn->DoConvolveBackwardData(
        this, filter_descriptor, filter_data, output_descriptor,
        backward_output_data, convolution_descriptor, input_descriptor,
        backward_input_data, algorithm_desc, &scratch_memory,
        output_profile_result);
  }
  // Failures are tolerated while profiling; otherwise mark the stream bad.
  if (!succeeded && output_profile_result == nullptr) {
    SetError();
  }
  return *this;
}
// Enqueues a double-precision backward-filter convolution using an explicit
// algorithm configuration. A failure only poisons the stream when the caller
// is not profiling (output_profile_result == nullptr).
Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<double> &input_data,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<double> backward_output_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::FilterDescriptor &filter_descriptor,
    DeviceMemory<double> *backward_filter_data,
    ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result) {
  // Also log algorithm_config, for consistency with the backward-data
  // Eigen::half overload (it was missing from this overload's VLOG_CALL).
  VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
            PARAM(output_descriptor), PARAM(backward_output_data),
            PARAM(convolution_descriptor), PARAM(filter_descriptor),
            PARAM(backward_filter_data), PARAM(algorithm_config));
  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      DeviceMemory<uint8> scratch_memory;
      dnn::AlgorithmDesc algorithm_desc;
      // Resolve the algorithm and allocate any scratch memory it requires.
      auto status =
          dnn->PrepareForConvolution(
                 dnn::ConvolutionKind::BACKWARD_FILTER, this, input_descriptor,
                 input_data, filter_descriptor, *backward_filter_data,
                 output_descriptor, backward_output_data,
                 convolution_descriptor, algorithm_config, scratch_allocator,
                 &algorithm_desc, &scratch_memory)
              .ok();
      if (status) {
        status = dnn->DoConvolveBackwardFilter(
            this, input_descriptor, input_data, output_descriptor,
            backward_output_data, convolution_descriptor, filter_descriptor,
            backward_filter_data, algorithm_desc, &scratch_memory,
            output_profile_result);
      }
      if (!status && !output_profile_result) {
        SetError();
      }
    } else {
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}
// Enqueues a single-precision backward-filter convolution using an explicit
// algorithm configuration. A failure only poisons the stream when the caller
// is not profiling (output_profile_result == nullptr).
Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<float> &input_data,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<float> backward_output_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::FilterDescriptor &filter_descriptor,
    DeviceMemory<float> *backward_filter_data,
    ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result) {
  // Also log algorithm_config, for consistency with the backward-data
  // Eigen::half overload (it was missing from this overload's VLOG_CALL).
  VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
            PARAM(output_descriptor), PARAM(backward_output_data),
            PARAM(convolution_descriptor), PARAM(filter_descriptor),
            PARAM(backward_filter_data), PARAM(algorithm_config));
  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      DeviceMemory<uint8> scratch_memory;
      dnn::AlgorithmDesc algorithm_desc;
      // Resolve the algorithm and allocate any scratch memory it requires.
      auto status =
          dnn->PrepareForConvolution(
                 dnn::ConvolutionKind::BACKWARD_FILTER, this, input_descriptor,
                 input_data, filter_descriptor, *backward_filter_data,
                 output_descriptor, backward_output_data,
                 convolution_descriptor, algorithm_config, scratch_allocator,
                 &algorithm_desc, &scratch_memory)
              .ok();
      if (status) {
        status = dnn->DoConvolveBackwardFilter(
            this, input_descriptor, input_data, output_descriptor,
            backward_output_data, convolution_descriptor, filter_descriptor,
            backward_filter_data, algorithm_desc, &scratch_memory,
            output_profile_result);
      }
      if (!status && !output_profile_result) {
        SetError();
      }
    } else {
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}
// Enqueues a half-precision backward-filter convolution using an explicit
// algorithm configuration. A failure only poisons the stream when the caller
// is not profiling (output_profile_result == nullptr).
Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<Eigen::half> &input_data,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<Eigen::half> backward_output_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::FilterDescriptor &filter_descriptor,
    DeviceMemory<Eigen::half> *backward_filter_data,
    ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result) {
  // Also log algorithm_config, for consistency with the backward-data
  // Eigen::half overload (it was missing from this overload's VLOG_CALL).
  VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
            PARAM(output_descriptor), PARAM(backward_output_data),
            PARAM(convolution_descriptor), PARAM(filter_descriptor),
            PARAM(backward_filter_data), PARAM(algorithm_config));
  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      DeviceMemory<uint8> scratch_memory;
      dnn::AlgorithmDesc algorithm_desc;
      // Resolve the algorithm and allocate any scratch memory it requires.
      auto status =
          dnn->PrepareForConvolution(
                 dnn::ConvolutionKind::BACKWARD_FILTER, this, input_descriptor,
                 input_data, filter_descriptor, *backward_filter_data,
                 output_descriptor, backward_output_data,
                 convolution_descriptor, algorithm_config, scratch_allocator,
                 &algorithm_desc, &scratch_memory)
              .ok();
      if (status) {
        status = dnn->DoConvolveBackwardFilter(
            this, input_descriptor, input_data, output_descriptor,
            backward_output_data, convolution_descriptor, filter_descriptor,
            backward_filter_data, algorithm_desc, &scratch_memory,
            output_profile_result);
      }
      if (!status && !output_profile_result) {
        SetError();
      }
    } else {
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}
// Shared implementation behind the typed ThenConvolveBackwardBias overloads:
// computes the bias gradient for a convolution on this stream.
template <typename T>
Stream &Stream::ThenConvolveBackwardBiasImpl(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<T> &input_data,
    const dnn::BatchDescriptor &bias_descriptor,
    DeviceMemory<T> *backward_bias_data) {
  VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), PARAM(bias_descriptor),
            PARAM(backward_bias_data));
  // Do not enqueue anything onto a stream that is already in error.
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(dnn->DoConvolveBackwardBias(
      this, input_descriptor, input_data, bias_descriptor,
      backward_bias_data));
  return *this;
}
// Backward-bias convolution for double data; forwards to the shared impl.
Stream &Stream::ThenConvolveBackwardBias(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<double> &input_data,
    const dnn::BatchDescriptor &bias_descriptor,
    DeviceMemory<double> *backward_bias_data) {
  return ThenConvolveBackwardBiasImpl<double>(input_descriptor, input_data,
                                              bias_descriptor,
                                              backward_bias_data);
}
// Backward-bias convolution for float data; forwards to the shared impl.
Stream &Stream::ThenConvolveBackwardBias(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<float> &input_data,
    const dnn::BatchDescriptor &bias_descriptor,
    DeviceMemory<float> *backward_bias_data) {
  return ThenConvolveBackwardBiasImpl<float>(input_descriptor, input_data,
                                             bias_descriptor,
                                             backward_bias_data);
}
// Backward-bias convolution for half data; forwards to the shared impl.
Stream &Stream::ThenConvolveBackwardBias(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<Eigen::half> &input_data,
    const dnn::BatchDescriptor &bias_descriptor,
    DeviceMemory<Eigen::half> *backward_bias_data) {
  return ThenConvolveBackwardBiasImpl<Eigen::half>(input_descriptor,
                                                   input_data, bias_descriptor,
                                                   backward_bias_data);
}
// Enqueues a fully-connected (matrix-multiply) layer on this stream.
Stream &Stream::ThenMatMul(const DeviceMemory<float> &input_data,
                           const DeviceMemory<float> &weights,
                           const dnn::BatchDescriptor &input_dimensions,
                           const dnn::BatchDescriptor &output_dimensions,
                           DeviceMemory<float> *output_data) {
  VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(input_dimensions),
            PARAM(output_dimensions), PARAM(output_data));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(dnn->DoMatMul(this, input_data, weights, input_dimensions,
                           output_dimensions, output_data));
  return *this;
}
// Enqueues a matrix multiply with quantized int8 weights and per-channel
// float scales on this stream.
Stream &Stream::ThenMatMulQuantized(
    const DeviceMemory<float> &input_data, const DeviceMemory<int8> &weights,
    const DeviceMemory<float> &weight_scales,
    const dnn::BatchDescriptor &input_dimensions,
    const dnn::BatchDescriptor &output_dimensions,
    DeviceMemory<float> *output_data) {
  VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
            PARAM(input_dimensions), PARAM(output_dimensions),
            PARAM(output_data));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(dnn->DoMatMulQuantized(this, input_data, weights, weight_scales,
                                    input_dimensions, output_dimensions,
                                    output_data));
  return *this;
}
// Enqueues a matrix multiply with quantized int16 weights and per-channel
// float scales on this stream.
Stream &Stream::ThenMatMulQuantized(
    const DeviceMemory<float> &input_data, const DeviceMemory<int16> &weights,
    const DeviceMemory<float> &weight_scales,
    const dnn::BatchDescriptor &input_dimensions,
    const dnn::BatchDescriptor &output_dimensions,
    DeviceMemory<float> *output_data) {
  VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
            PARAM(input_dimensions), PARAM(output_dimensions),
            PARAM(output_data));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(dnn->DoMatMulQuantized(this, input_data, weights, weight_scales,
                                    input_dimensions, output_dimensions,
                                    output_data));
  return *this;
}
// Enqueues an elementwise bias-add over input_data on this stream.
Stream &Stream::ThenBiasAdd(const DeviceMemory<float> &input_data,
                            const DeviceMemory<float> &biases,
                            const dnn::BatchDescriptor &dimensions,
                            DeviceMemory<float> *output_data) {
  VLOG_CALL(PARAM(input_data), PARAM(biases), PARAM(dimensions),
            PARAM(output_data));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(
      dnn->DoBiasAdd(this, input_data, biases, dimensions, output_data));
  return *this;
}
// Enqueues a double-precision forward pooling operation on this stream.
Stream &Stream::ThenPoolForward(
    const dnn::PoolingDescriptor &pooling_dimensions,
    const dnn::BatchDescriptor &input_dimensions,
    const DeviceMemory<double> &input_data,
    const dnn::BatchDescriptor &output_dimensions,
    DeviceMemory<double> *output_data, ScratchAllocator *workspace_allocator) {
  VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
            PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
            PARAM(workspace_allocator));
  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
                                    input_data, output_dimensions, output_data,
                                    workspace_allocator));
    } else {
      // Use the shared helper for consistency with the other Then* methods
      // in this file (the original inlined SetError() + LOG(WARNING) here).
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}
// Enqueues a single-precision forward pooling operation on this stream.
Stream &Stream::ThenPoolForward(
    const dnn::PoolingDescriptor &pooling_dimensions,
    const dnn::BatchDescriptor &input_dimensions,
    const DeviceMemory<float> &input_data,
    const dnn::BatchDescriptor &output_dimensions,
    DeviceMemory<float> *output_data, ScratchAllocator *workspace_allocator) {
  VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
            PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
            PARAM(workspace_allocator));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
                                input_data, output_dimensions, output_data,
                                workspace_allocator));
  return *this;
}
// Enqueues a half-precision forward pooling operation on this stream.
Stream &Stream::ThenPoolForward(
    const dnn::PoolingDescriptor &pooling_dimensions,
    const dnn::BatchDescriptor &input_dimensions,
    const DeviceMemory<Eigen::half> &input_data,
    const dnn::BatchDescriptor &output_dimensions,
    DeviceMemory<Eigen::half> *output_data,
    ScratchAllocator *workspace_allocator) {
  VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
            PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
            PARAM(workspace_allocator));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
                                input_data, output_dimensions, output_data,
                                workspace_allocator));
  return *this;
}
// Enqueues an int8 forward pooling operation on this stream.
Stream &Stream::ThenPoolForward(
    const dnn::PoolingDescriptor &pooling_dimensions,
    const dnn::BatchDescriptor &input_dimensions,
    const DeviceMemory<int8> &input_data,
    const dnn::BatchDescriptor &output_dimensions,
    DeviceMemory<int8> *output_data, ScratchAllocator *workspace_allocator) {
  VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
            PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
            PARAM(workspace_allocator));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
                                input_data, output_dimensions, output_data,
                                workspace_allocator));
  return *this;
}
// Enqueues a double-precision backward pooling operation on this stream.
Stream &Stream::ThenPoolBackward(
    const dnn::PoolingDescriptor &pooling_dimensions,
    const dnn::BatchDescriptor &input_dimensions,
    const DeviceMemory<double> &input_data,
    const dnn::BatchDescriptor &output_dimensions,
    const DeviceMemory<double> &output_data,
    const DeviceMemory<double> &input_diff_data,
    DeviceMemory<double> *output_diff_data,
    ScratchAllocator *workspace_allocator) {
  VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
            PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
            PARAM(input_diff_data), PARAM(output_diff_data),
            PARAM(workspace_allocator));
  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
                                     input_data, output_dimensions, output_data,
                                     input_diff_data, output_diff_data,
                                     workspace_allocator));
    } else {
      // Use the shared helper for consistency with the other Then* methods
      // in this file (the original inlined SetError() + LOG(WARNING) here).
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}
// Enqueues a single-precision backward pooling operation on this stream.
Stream &Stream::ThenPoolBackward(
    const dnn::PoolingDescriptor &pooling_dimensions,
    const dnn::BatchDescriptor &input_dimensions,
    const DeviceMemory<float> &input_data,
    const dnn::BatchDescriptor &output_dimensions,
    const DeviceMemory<float> &output_data,
    const DeviceMemory<float> &input_diff_data,
    DeviceMemory<float> *output_diff_data,
    ScratchAllocator *workspace_allocator) {
  VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
            PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
            PARAM(input_diff_data), PARAM(output_diff_data),
            PARAM(workspace_allocator));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
                                 input_data, output_dimensions, output_data,
                                 input_diff_data, output_diff_data,
                                 workspace_allocator));
  return *this;
}
// Enqueues a half-precision backward pooling operation on this stream.
Stream &Stream::ThenPoolBackward(
    const dnn::PoolingDescriptor &pooling_dimensions,
    const dnn::BatchDescriptor &input_dimensions,
    const DeviceMemory<Eigen::half> &input_data,
    const dnn::BatchDescriptor &output_dimensions,
    const DeviceMemory<Eigen::half> &output_data,
    const DeviceMemory<Eigen::half> &input_diff_data,
    DeviceMemory<Eigen::half> *output_diff_data,
    ScratchAllocator *workspace_allocator) {
  VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
            PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
            PARAM(input_diff_data), PARAM(output_diff_data),
            PARAM(workspace_allocator));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
                                 input_data, output_dimensions, output_data,
                                 input_diff_data, output_diff_data,
                                 workspace_allocator));
  return *this;
}
// Enqueues a local-response-normalization forward pass on this stream.
Stream &Stream::ThenNormalizeWithDimensions(
    const dnn::NormalizeDescriptor &normalize_descriptor,
    const dnn::BatchDescriptor &dimensions,
    const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data) {
  VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(input_data),
            PARAM(output_data));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(dnn->DoNormalizeWithDimensions(this, normalize_descriptor,
                                            dimensions, input_data,
                                            output_data));
  return *this;
}
// Enqueues the backward pass of local-response normalization on this stream.
Stream &Stream::ThenNormalizeBackwardWithDimensions(
    const dnn::NormalizeDescriptor &normalize_descriptor,
    const dnn::BatchDescriptor &dimensions, const DeviceMemory<float> &raw_data,
    const DeviceMemory<float> &normalized_data,
    const DeviceMemory<float> &normalized_variable_gradient,
    DeviceMemory<float> *raw_variable_gradient,
    ScratchAllocator *workspace_allocator) {
  VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(raw_data),
            PARAM(normalized_data), PARAM(normalized_variable_gradient),
            PARAM(raw_variable_gradient), PARAM(workspace_allocator));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(dnn->DoNormalizeBackwardWithDimensions(
      this, normalize_descriptor, dimensions, raw_data, normalized_data,
      normalized_variable_gradient, raw_variable_gradient,
      workspace_allocator));
  return *this;
}
// Enqueues an activation function with default (zero) options; forwards to
// ThenActivateWithOptions.
Stream &Stream::ThenActivate(dnn::ActivationMode activation_mode,
                             const dnn::BatchDescriptor &dimensions,
                             const DeviceMemory<float> &input_data,
                             DeviceMemory<float> *output_data) {
  constexpr uint64 kNoOptions = 0;
  return ThenActivateWithOptions(activation_mode, dimensions, input_data,
                                 output_data, kNoOptions);
}
// Enqueues an activation function on this stream with platform-specific
// option bits.
Stream &Stream::ThenActivateWithOptions(dnn::ActivationMode activation_mode,
                                        const dnn::BatchDescriptor &dimensions,
                                        const DeviceMemory<float> &input_data,
                                        DeviceMemory<float> *output_data,
                                        uint64 options) {
  VLOG_CALL(PARAM(activation_mode), PARAM(dimensions), PARAM(input_data),
            PARAM(output_data), PARAM(options));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(dnn->DoActivate(this, activation_mode, dimensions, input_data,
                             output_data, options));
  return *this;
}
// Enqueues a depth (feature-map) concatenation of the inputs on this stream.
// All inputs must agree on batch count, height, and width; only the feature
// map counts may differ.
Stream &Stream::ThenDepthConcatenate(
    port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
    port::ArraySlice<const DeviceMemory<float> *> input_data,
    DeviceMemory<float> *output_data) {
  VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
  for (size_t i = 1; i < input_dimensions.size(); ++i) {
    if (input_dimensions[i].count() != input_dimensions[0].count() ||
        input_dimensions[i].height() != input_dimensions[0].height() ||
        input_dimensions[i].width() != input_dimensions[0].width()) {
      SetError();
      // Separate the two descriptor dumps with a newline; the original
      // message ran them together.
      LOG(ERROR) << "Incompatible dimensions for depth concatenation.\n"
                 << "input_dimensions[0]: " << input_dimensions[0].ToString()
                 << "\ninput_dimensions[" << i
                 << "]: " << input_dimensions[i].ToString();
      return *this;
    }
  }
  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      CheckError(dnn->DoDepthConcatenate(this, input_dimensions, input_data,
                                         output_data));
    } else {
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}
// Enqueues a spatial (X- or Y-direction) concatenation of the inputs on this
// stream. For X concatenation all inputs must agree on count, height, and
// feature-map count; for Y concatenation on count, width, and feature-map
// count.
Stream &Stream::ThenSpaceConcatenate(
    port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
    port::ArraySlice<const DeviceMemory<float> *> input_data,
    DeviceMemory<float> *output_data,
    dnn::SpaceConcatenateMode concat_direction) {
  VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
  // Check that the input dimensions of all the other batches match those of the
  // first batch.
  for (size_t i = 1; i < input_dimensions.size(); ++i) {
    if ((concat_direction == dnn::SpaceConcatenateMode::XDirection) &&
        (input_dimensions[i].count() != input_dimensions[0].count() ||
         input_dimensions[i].height() != input_dimensions[0].height() ||
         input_dimensions[i].feature_map_count() !=
             input_dimensions[0].feature_map_count())) {
      SetError();
      // Separate the two descriptor dumps with a newline; the original
      // message ran them together.
      LOG(ERROR) << "Incompatible dimensions for X concatenation.\n"
                 << "input_dimensions[0]: " << input_dimensions[0].ToString()
                 << "\ninput_dimensions[" << i
                 << "]: " << input_dimensions[i].ToString();
      return *this;
    }
    if ((concat_direction == dnn::SpaceConcatenateMode::YDirection) &&
        (input_dimensions[i].count() != input_dimensions[0].count() ||
         input_dimensions[i].width() != input_dimensions[0].width() ||
         input_dimensions[i].feature_map_count() !=
             input_dimensions[0].feature_map_count())) {
      SetError();
      LOG(ERROR) << "Incompatible dimensions for Y concatenation.\n"
                 << "input_dimensions[0]: " << input_dimensions[0].ToString()
                 << "\ninput_dimensions[" << i
                 << "]: " << input_dimensions[i].ToString();
      return *this;
    }
  }
  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      CheckError(dnn->DoSpaceConcatenate(this, input_dimensions, input_data,
                                         output_data, concat_direction));
    } else {
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}
// Enqueues a reshape of input_data into output_dimensions on this stream.
Stream &Stream::ThenReshape(const dnn::BatchDescriptor &input_dimensions,
                            const DeviceMemory<float> &input_data,
                            const dnn::BatchDescriptor &output_dimensions,
                            DeviceMemory<float> *output_data) {
  VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
            PARAM(output_dimensions), PARAM(output_data));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(dnn->DoReshape(this, input_dimensions, input_data,
                            output_dimensions, output_data));
  return *this;
}
// Enqueues a depth-to-space rearrangement on this stream.
Stream &Stream::ThenDepthToSpace(
    const dnn::BatchDescriptor &input_dimensions,
    const DeviceMemory<float> &input_data,
    const dnn::DepthToSpaceLayout &depth_to_space_layout,
    const int sqrt_depth_reduction, DeviceMemory<float> *output_data) {
  VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
            PARAM(depth_to_space_layout), PARAM(sqrt_depth_reduction),
            PARAM(output_data));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(dnn->DoDepthToSpace(this, input_dimensions, input_data,
                                 depth_to_space_layout, sqrt_depth_reduction,
                                 output_data));
  return *this;
}
// Enqueues a space-to-depth rearrangement on this stream.
Stream &Stream::ThenSpaceToDepth(
    const dnn::BatchDescriptor &input_dimensions,
    const DeviceMemory<float> &input_data,
    const dnn::DepthToSpaceLayout &space_to_depth_layout,
    const int sqrt_depth_increase, DeviceMemory<float> *output_data) {
  VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
            PARAM(space_to_depth_layout), PARAM(sqrt_depth_increase),
            PARAM(output_data));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(dnn->DoSpaceToDepth(this, input_dimensions, input_data,
                                 space_to_depth_layout, sqrt_depth_increase,
                                 output_data));
  return *this;
}
// Enqueues an elementwise operation over the given inputs on this stream.
Stream &Stream::ThenElementwiseOperate(
    dnn::ElementwiseOperation operation,
    port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
    port::ArraySlice<const DeviceMemory<float> *> input_data,
    const dnn::BatchDescriptor &output_dimensions,
    DeviceMemory<float> *output_data) {
  VLOG_CALL(PARAM(operation), PARAM(input_dimensions), PARAM(input_data),
            PARAM(output_dimensions), PARAM(output_data));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(dnn->DoElementwiseOperate(this, operation, input_dimensions,
                                       input_data, output_dimensions,
                                       output_data));
  return *this;
}
// Enqueues a scaled, quantized elementwise operation on this stream: each
// input is scaled by its multiplicand and the result divided by
// output_divisor.
Stream &Stream::ThenElementwiseOperateScaledQuantized(
    dnn::ElementwiseOperation operation,
    port::ArraySlice<int> input_multiplicands, int output_divisor,
    port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
    port::ArraySlice<const DeviceMemory<float> *> input_data,
    const dnn::BatchDescriptor &output_dimensions,
    DeviceMemory<float> *output_data) {
  VLOG_CALL(PARAM(operation), PARAM(input_multiplicands), PARAM(output_divisor),
            PARAM(input_dimensions), PARAM(input_data),
            PARAM(output_dimensions), PARAM(output_data));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(dnn->DoElementwiseOperateScaledQuantized(
      this, operation, input_multiplicands, output_divisor, input_dimensions,
      input_data, output_dimensions, output_data));
  return *this;
}
// Enqueues a zero-padding of the X/Y borders of input_data on this stream.
Stream &Stream::ThenXYPad(const dnn::BatchDescriptor &dimensions,
                          const DeviceMemory<float> &input_data, int64 left_pad,
                          int64 right_pad, int64 top_pad, int64 bottom_pad,
                          DeviceMemory<float> *output_data) {
  VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_pad),
            PARAM(right_pad), PARAM(top_pad), PARAM(bottom_pad),
            PARAM(output_data));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(dnn->DoXYPad(this, dimensions, input_data, left_pad, right_pad,
                          top_pad, bottom_pad, output_data));
  return *this;
}
// Enqueues a trimming (slice) of the X/Y borders of input_data on this
// stream.
Stream &Stream::ThenXYSlice(const dnn::BatchDescriptor &dimensions,
                            const DeviceMemory<float> &input_data,
                            int64 left_trim, int64 right_trim, int64 top_trim,
                            int64 bottom_trim,
                            DeviceMemory<float> *output_data) {
  VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_trim),
            PARAM(right_trim), PARAM(top_trim), PARAM(bottom_trim),
            PARAM(output_data));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(dnn->DoXYSlice(this, dimensions, input_data, left_trim,
                            right_trim, top_trim, bottom_trim, output_data));
  return *this;
}
// Enqueues an X/Y replication (tiling) of input_data on this stream.
Stream &Stream::ThenXYBroadcast(const dnn::BatchDescriptor &dimensions,
                                const DeviceMemory<float> &input_data,
                                int64 replicate_x, int64 replicate_y,
                                DeviceMemory<float> *output_data) {
  VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(replicate_x),
            PARAM(replicate_y), PARAM(output_data));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(dnn->DoXYBroadcast(this, dimensions, input_data, replicate_x,
                                replicate_y, output_data));
  return *this;
}
// Enqueues a quantizing device-to-host copy: gpu_unquantized_src is quantized
// according to `mode` into the `size`-byte host buffer.
Stream &Stream::ThenMemcpyD2HQuantized(
    const DeviceMemory<float> &gpu_unquantized_src,
    dnn::QuantizedActivationMode mode, void *host_dst, uint64 size) {
  VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(mode), PARAM(host_dst),
            PARAM(size));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, mode,
                                       host_dst, size));
  return *this;
}
// Enqueues a dequantizing host-to-device copy: the `size`-byte host buffer is
// dequantized according to `mode` into gpu_unquantized_dst.
Stream &Stream::ThenMemcpyH2DQuantized(
    const void *host_src, uint64 size, dnn::QuantizedActivationMode mode,
    DeviceMemory<float> *gpu_unquantized_dst) {
  VLOG_CALL(PARAM(host_src), PARAM(size), PARAM(mode),
            PARAM(gpu_unquantized_dst));
  if (!ok()) {
    return *this;
  }
  dnn::DnnSupport *dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }
  CheckError(dnn->DoMemcpyH2DQuantized(this, host_src, size, mode,
                                       gpu_unquantized_dst));
  return *this;
}
// Returns a reusable sub-stream, creating a new one if none is available.
// Holds mu_ for the duration; !ok sub-streams encountered during the scan are
// destroyed, since a stream's error state is permanent.
Stream *Stream::GetOrCreateSubStream() {
  absl::MutexLock lock(&mu_);
  // Look for the first reusable sub_stream that is ok, dropping !ok sub_streams
  // we encounter along the way.
  for (int64 index = 0; index < sub_streams_.size();) {
    std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
    if (pair.second) {
      // The sub_stream is reusable.
      Stream *sub_stream = pair.first.get();
      if (sub_stream->ok()) {
        VLOG(1) << DebugStreamPointers() << " reusing sub_stream "
                << sub_stream->DebugStreamPointers();
        pair.second = false;
        return sub_stream;
      }
      // The stream is reusable and not ok. Streams have a monotonic state
      // machine; the stream will remain in !ok forever. Swap it with the last
      // stream and pop it off.
      //
      // Log BEFORE the swap/pop below: pop_back() destroys the Stream that
      // sub_stream points at, so dereferencing it afterwards (as the original
      // code did) is a use-after-free when VLOG(1) is enabled.
      VLOG(1) << DebugStreamPointers() << " dropped !ok sub_stream "
              << sub_stream->DebugStreamPointers();
      const int64 last = sub_streams_.size() - 1;
      if (index != last) {
        std::swap(pair, sub_streams_[last]);
      }
      sub_streams_.pop_back();
    } else {
      // The sub_stream is not reusable, move on to the next one.
      ++index;
    }
  }
  // No streams are reusable; create a new stream.
  sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}},
                            false);
  Stream *sub_stream = sub_streams_.back().first.get();
  sub_stream->Init();
  if (!sub_stream->ok_) {
    LOG(ERROR) << "sub-stream failed to be initialized";
  }
  VLOG(1) << DebugStreamPointers() << " created new sub_stream "
          << sub_stream->DebugStreamPointers();
  return sub_stream;
}
// Returns a sub-stream previously handed out by GetOrCreateSubStream back to
// this stream's pool. Holds mu_ for the duration. An ok sub-stream is marked
// reusable; a !ok one is destroyed (error state is permanent). LOG(FATAL)s if
// sub_stream was not created by this stream.
void Stream::ReturnSubStream(Stream *sub_stream) {
  absl::MutexLock lock(&mu_);
  // Look for the sub-stream.
  for (int64 index = 0; index < sub_streams_.size(); ++index) {
    std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
    if (pair.first.get() != sub_stream) {
      continue;
    }
    // Found the sub_stream.
    if (sub_stream->ok()) {
      VLOG(1) << DebugStreamPointers() << " returned ok sub_stream "
              << sub_stream->DebugStreamPointers();
      // Mark it reusable so GetOrCreateSubStream can hand it out again.
      pair.second = true;
    } else {
      // The returned stream is not ok. Streams have a monotonic state
      // machine; the stream will remain in !ok forever. Swap it with the last
      // stream and pop it off.
      // Note: the log happens before pop_back(), while sub_stream is alive.
      VLOG(1) << DebugStreamPointers() << " returned !ok sub_stream "
              << sub_stream->DebugStreamPointers();
      const int64 last = sub_streams_.size() - 1;
      if (index != last) {
        std::swap(pair, sub_streams_[last]);
      }
      // Destroys the unique_ptr owning sub_stream.
      sub_streams_.pop_back();
    }
    return;
  }
  LOG(FATAL) << DebugStreamPointers()
             << " did not create the returned sub-stream "
             << sub_stream->DebugStreamPointers();
}
Stream &Stream::ThenStartTimer(Timer *t) {
  VLOG_CALL(PARAM(t));
  // Skip the enqueue entirely on an errored stream.
  if (!ok()) {
    LOG(INFO) << DebugStreamPointers()
              << " did not enqueue 'start timer': " << t;
    return *this;
  }
  CheckError(parent_->StartTimer(this, t));
  return *this;
}
Stream &Stream::ThenStopTimer(Timer *t) {
  VLOG_CALL(PARAM(t));
  // Skip the enqueue entirely on an errored stream.
  if (!ok()) {
    LOG(INFO) << DebugStreamPointers()
              << " did not enqueue 'stop timer': " << t;
    return *this;
  }
  CheckError(parent_->StopTimer(this, t));
  return *this;
}
Stream &Stream::ThenWaitFor(Stream *other) {
  VLOG_CALL(PARAM(other));
  CHECK(this != other) << "stream cannot wait for itself";
  // If either stream is already in an error state, the dependency cannot be
  // established; mark this stream as errored and bail out.
  if (!ok() || !other->ok()) {
    SetError();
    LOG(INFO) << DebugStreamPointers() << " did not wait for "
              << other->DebugStreamPointers();
    return *this;
  }
  CheckError(parent_->CreateStreamDependency(this, other));
  return *this;
}
Stream &Stream::ThenWaitFor(Event *event) {
  VLOG_CALL(PARAM(event));
  if (!ok()) {
    LOG(INFO) << DebugStreamPointers() << " did not wait for an event.";
    return *this;
  }
  const port::Status status = parent_->WaitForEvent(this, event);
  if (!status.ok()) {
    // Deliberately do not poison this stream: the failure may belong to the
    // Event rather than to us.
    LOG(ERROR) << "Error waiting for event in stream: "
               << status.error_message()
               << "; not marking stream as bad, as the Event object may be "
               << "at fault. Monitor for further errors.";
  }
  return *this;
}
// A functor that implements ThenBlasXXX interfaces: it invokes the given
// DoBlasXXX member function on the stream's BLAS plugin and logs failures.
// We need a functor (rather than a plain helper function) so the variadic
// argument pack can be named once per call site.
template <typename... Args>
struct ThenBlasImpl {
  // Invokes blas_func and records any error on the stream. Calling
  // CheckError() can be suppressed via Run() below.
  Stream &Run(Stream *stream,
              bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
              bool record_error, Args... args);

  // blas_func is the DoBlasXXX member function pointer, and args are its
  // arguments except the first one of Stream* type.
  Stream &operator()(Stream *stream,
                     bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
                     Args... args) {
    constexpr bool kRecordError = true;
    return Run(stream, blas_func, kRecordError, args...);
  }
};
template <typename... Args>
Stream &ThenBlasImpl<Args...>::Run(
    Stream *stream, bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
    bool record_error, Args... args) {
  // Errored streams enqueue nothing; return immediately.
  if (!stream->ok()) {
    return *stream;
  }
  bool call_ok = false;
  blas::BlasSupport *blas = stream->parent_->AsBlas();
  if (blas != nullptr) {
    call_ok = (blas->*blas_func)(stream, args...);
  } else {
    // No BLAS plugin is registered for this executor.
    LOG(WARNING)
        << "attempting to perform BLAS operation using StreamExecutor "
           "without BLAS support";
  }
  if (record_error) {
    stream->CheckError(call_ok);
  }
  return *stream;
}
Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<float> &x,
                             int incx, DeviceMemory<float> *result) {
  VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
  // Dispatch to the BLAS plugin's DoBlasAsum for single-precision data.
  using Impl = ThenBlasImpl<uint64, const DeviceMemory<float> &, int,
                            DeviceMemory<float> *>;
  return Impl()(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
                result);
}
Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<double> &x,
                             int incx, DeviceMemory<double> *result) {
  VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
  // Dispatch to the BLAS plugin's DoBlasAsum for double-precision data.
  using Impl = ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
                            DeviceMemory<double> *>;
  return Impl()(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
                result);
}
Stream &Stream::ThenBlasAsum(uint64 elem_count,
                             const DeviceMemory<std::complex<float>> &x,
                             int incx, DeviceMemory<float> *result) {
  VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
  // Dispatch to the BLAS plugin's DoBlasAsum for complex<float> data.
  using Impl = ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &,
                            int, DeviceMemory<float> *>;
  return Impl()(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
                result);
}
Stream &Stream::ThenBlasAsum(uint64 elem_count,
                             const DeviceMemory<std::complex<double>> &x,
                             int incx, DeviceMemory<double> *result) {
  VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
  // Dispatch to the BLAS plugin's DoBlasAsum for complex<double> data.
  using Impl = ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &,
                            int, DeviceMemory<double> *>;
  return Impl()(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
                result);
}
Stream &Stream::ThenBlasAxpy(uint64 elem_count, float alpha,
                             const DeviceMemory<float> &x, int incx,
                             DeviceMemory<float> *y, int incy) {
  VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
            PARAM(incy));
  // Dispatch to the BLAS plugin's DoBlasAxpy for single-precision data.
  using Impl = ThenBlasImpl<uint64, float, const DeviceMemory<float> &, int,
                            DeviceMemory<float> *, int>;
  return Impl()(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x,
                incx, y, incy);
}
Stream &Stream::ThenBlasAxpy(uint64 elem_count, double alpha,
                             const DeviceMemory<double> &x, int incx,
                             DeviceMemory<double> *y, int incy) {
  VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
            PARAM(incy));
  // Dispatch to the BLAS plugin's DoBlasAxpy for double-precision data.
  using Impl = ThenBlasImpl<uint64, double, const DeviceMemory<double> &, int,
                            DeviceMemory<double> *, int>;
  return Impl()(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x,
                incx, y, incy);
}
Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<float> alpha,
                             const DeviceMemory<std::complex<float>> &x,
                             int incx, DeviceMemory<std::complex<float>> *y,
                             int incy) {
  VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
            PARAM(incy));
  // Dispatch to the BLAS plugin's DoBlasAxpy for complex<float> data.
  using Impl = ThenBlasImpl<uint64, std::complex<float>,
                            const DeviceMemory<std::complex<float>> &, int,
                            DeviceMemory<std::complex<float>> *, int>;
  return Impl()(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x,
                incx, y, incy);
}
Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<double> alpha,
                             const DeviceMemory<std::complex<double>> &x,
                             int incx, DeviceMemory<std::complex<double>> *y,
                             int incy) {
  VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
            PARAM(incy));
  // Dispatch to the BLAS plugin's DoBlasAxpy for complex<double> data.
  using Impl = ThenBlasImpl<uint64, std::complex<double>,
                            const DeviceMemory<std::complex<double>> &, int,
                            DeviceMemory<std::complex<double>> *, int>;
  return Impl()(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x,
                incx, y, incy);
}
Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<float> &x,
                             int incx, DeviceMemory<float> *y, int incy) {
  VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
  // Dispatch to the BLAS plugin's DoBlasCopy for single-precision data.
  using Impl = ThenBlasImpl<uint64, const DeviceMemory<float> &, int,
                            DeviceMemory<float> *, int>;
  return Impl()(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
                incy);
}
Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<double> &x,
                             int incx, DeviceMemory<double> *y, int incy) {
  VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
  // Dispatch to the BLAS plugin's DoBlasCopy for double-precision data.
  using Impl = ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
                            DeviceMemory<double> *, int>;
  return Impl()(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
                incy);
}
Stream &Stream::ThenBlasCopy(uint64 elem_count,
                             const DeviceMemory<std::complex<float>> &x,
                             int incx, DeviceMemory<std::complex<float>> *y,
                             int incy) {
  VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
  // Dispatch to the BLAS plugin's DoBlasCopy for complex<float> data.
  using Impl = ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &,
                            int, DeviceMemory<std::complex<float>> *, int>;
  return Impl()(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
                incy);
}
Stream &Stream::ThenBlasCopy(uint64 elem_count,
                             const DeviceMemory<std::complex<double>> &x,
                             int incx, DeviceMemory<std::complex<double>> *y,
                             int incy) {
  VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
  // Dispatch to the BLAS plugin's DoBlasCopy for complex<double> data.
  using Impl = ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &,
                            int, DeviceMemory<std::complex<double>> *, int>;
  return Impl()(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
                incy);
}
Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<float> &x,
                            int incx, const DeviceMemory<float> &y, int incy,
                            DeviceMemory<float> *result) {
  VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
            PARAM(result));
  // Dispatch to the BLAS plugin's DoBlasDot for single-precision data.
  using Impl =
      ThenBlasImpl<uint64, const DeviceMemory<float> &, int,
                   const DeviceMemory<float> &, int, DeviceMemory<float> *>;
  return Impl()(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y,
                incy, result);
}
Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<double> &x,
                            int incx, const DeviceMemory<double> &y, int incy,
                            DeviceMemory<double> *result) {
  VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
            PARAM(result));
  // Dispatch to the BLAS plugin's DoBlasDot for double-precision data.
  using Impl =
      ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
                   const DeviceMemory<double> &, int, DeviceMemory<double> *>;
  return Impl()(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y,
                incy, result);
}
Stream &Stream::ThenBlasDotc(uint64 elem_count,
                             const DeviceMemory<std::complex<float>> &x,
                             int incx,
                             const DeviceMemory<std::complex<float>> &y,
                             int incy,
                             DeviceMemory<std::complex<float>> *result) {
  VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
            PARAM(result));
  // Dispatch to the BLAS plugin's DoBlasDotc for complex<float> data.
  using Impl = ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &,
                            int, const DeviceMemory<std::complex<float>> &, int,
                            DeviceMemory<std::complex<float>> *>;
  return Impl()(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y,
                incy, result);
}
Stream &Stream::ThenBlasDotc(uint64 elem_count,
                             const DeviceMemory<std::complex<double>> &x,
                             int incx,
                             const DeviceMemory<std::complex<double>> &y,
                             int incy,
                             DeviceMemory<std::complex<double>> *result) {
  VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
            PARAM(result));
  // Dispatch to the BLAS plugin's DoBlasDotc for complex<double> data.
  using Impl = ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &,
                            int, const DeviceMemory<std::complex<double>> &,
                            int, DeviceMemory<std::complex<double>> *>;
  return Impl()(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y,
                incy, result);
}
Stream &Stream::ThenBlasDotu(uint64 elem_count,
                             const DeviceMemory<std::complex<float>> &x,
                             int incx,
                             const DeviceMemory<std::complex<float>> &y,
                             int incy,
                             DeviceMemory<std::complex<float>> *result) {
  VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
            PARAM(result));
  // Dispatch to the BLAS plugin's DoBlasDotu for complex<float> data.
  using Impl = ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &,
                            int, const DeviceMemory<std::complex<float>> &, int,
                            DeviceMemory<std::complex<float>> *>;
  return Impl()(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y,
                incy, result);
}
Stream &Stream::ThenBlasDotu(uint64 elem_count,
                             const DeviceMemory<std::complex<double>> &x,
                             int incx,
                             const DeviceMemory<std::complex<double>> &y,
                             int incy,
                             DeviceMemory<std::complex<double>> *result) {
  VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
            PARAM(result));
  // Dispatch to the BLAS plugin's DoBlasDotu for complex<double> data.
  using Impl = ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &,
                            int, const DeviceMemory<std::complex<double>> &,
                            int, DeviceMemory<std::complex<double>> *>;
  return Impl()(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y,
                incy, result);
}
Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<float> &x,
                             int incx, DeviceMemory<float> *result) {
  VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
  // Dispatch to the BLAS plugin's DoBlasNrm2 for single-precision data.
  using Impl = ThenBlasImpl<uint64, const DeviceMemory<float> &, int,
                            DeviceMemory<float> *>;
  return Impl()(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
                result);
}
Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<double> &x,
                             int incx, DeviceMemory<double> *result) {
  VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
  // Dispatch to the BLAS plugin's DoBlasNrm2 for double-precision data.
  using Impl = ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
                            DeviceMemory<double> *>;
  return Impl()(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
                result);
}
Stream &Stream::ThenBlasNrm2(uint64 elem_count,
                             const DeviceMemory<std::complex<float>> &x,
                             int incx, DeviceMemory<float> *result) {
  VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
  // Dispatch to the BLAS plugin's DoBlasNrm2 for complex<float> data.
  using Impl = ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &,
                            int, DeviceMemory<float> *>;
  return Impl()(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
                result);
}
Stream &Stream::ThenBlasNrm2(uint64 elem_count,
                             const DeviceMemory<std::complex<double>> &x,
                             int incx, DeviceMemory<double> *result) {
  VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
  // Dispatch to the BLAS plugin's DoBlasNrm2 for complex<double> data.
  using Impl = ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &,
                            int, DeviceMemory<double> *>;
  return Impl()(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
                result);
}
Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<float> *x, int incx,
                            DeviceMemory<float> *y, int incy, float c,
                            float s) {
  VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
            PARAM(c), PARAM(s));
  // Dispatch to the BLAS plugin's DoBlasRot for single-precision data.
  using Impl = ThenBlasImpl<uint64, DeviceMemory<float> *, int,
                            DeviceMemory<float> *, int, float, float>;
  return Impl()(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y,
                incy, c, s);
}
Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<double> *x,
                            int incx, DeviceMemory<double> *y, int incy,
                            double c, double s) {
  VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
            PARAM(c), PARAM(s));
  // Dispatch to the BLAS plugin's DoBlasRot for double-precision data.
  using Impl = ThenBlasImpl<uint64, DeviceMemory<double> *, int,
                            DeviceMemory<double> *, int, double, double>;
  return Impl()(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y,
                incy, c, s);
}
Stream &Stream::ThenBlasRot(uint64