blob: 0ac5af04582bb18f083ae1bc4914a6487b0aef33 [file] [log] [blame]
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#endif
#include "tensorflow/core/kernels/fused_batch_norm_op.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
namespace functor {
// TODO(ezhulenev): Use CUB reductions on GPU.
template <typename T, typename U>
struct FusedBatchNormFreezeGrad<GPUDevice, T, U> {
void operator()(OpKernelContext* context, const Tensor& y_backprop_input,
const Tensor& x_input, const Tensor& scale_input,
const Tensor& pop_mean_input,
const Tensor& pop_variance_input, U epsilon,
Tensor* x_backprop_output, Tensor* scale_backprop_output,
Tensor* offset_backprop_output) {
typename TTypes<T, 4>::ConstTensor y_backprop(
y_backprop_input.tensor<T, 4>());
typename TTypes<T, 4>::ConstTensor input(x_input.tensor<T, 4>());
typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
typename TTypes<U>::ConstVec pop_mean(pop_mean_input.vec<U>());
typename TTypes<U>::ConstVec pop_var(pop_variance_input.vec<U>());
typename TTypes<T, 4>::Tensor x_backprop(x_backprop_output->tensor<T, 4>());
typename TTypes<U>::Vec scale_backprop(scale_backprop_output->vec<U>());
typename TTypes<U>::Vec offset_backprop(offset_backprop_output->vec<U>());
const int depth = pop_mean.dimension(0);
const int rest_size = input.size() / depth;
// Allocate two temporary workspaces of [depth] shape.
Tensor scratch1_vec, scratch2_vec;
OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
{depth}, &scratch1_vec));
OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
{depth}, &scratch2_vec));
typename TTypes<U>::Vec scratch1(scratch1_vec.vec<U>());
typename TTypes<U>::Vec scratch2(scratch2_vec.vec<U>());
const GPUDevice& d = context->eigen_device<GPUDevice>();
Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
#if !defined(EIGEN_HAS_INDEX_LIST)
Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
Eigen::array<int, 1> reduction_axis{0};
Eigen::array<int, 2> rest_by_one({rest_size, 1});
#else
Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
one_by_depth.set(1, depth);
Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
Eigen::IndexList<Eigen::Index, Eigen::type2index<1> > rest_by_one;
rest_by_one.set(0, rest_size);
#endif
// offset_backprop = sum(y_backprop)
// scale_backprop = y_backprop * ((x - pop_mean) * rsqrt(pop_var + epsilon))
// x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon))
auto y_backprop_rest_by_depth =
y_backprop.reshape(rest_by_depth).template cast<U>();
auto input_rest_by_depth = input.reshape(rest_by_depth).template cast<U>();
offset_backprop.device(d) = y_backprop_rest_by_depth.sum(reduction_axis);
// scratch1 = rsqrt(pop_var + epsilon)
scratch1.device(d) = (pop_var + pop_var.constant(epsilon)).rsqrt();
// scratch2 = sum(y_backprop * (x - mean))
scratch2.device(d) =
(y_backprop_rest_by_depth *
(input_rest_by_depth -
pop_mean.reshape(one_by_depth).broadcast(rest_by_one)))
.sum(reduction_axis);
x_backprop.reshape(rest_by_depth).device(d) =
(y_backprop_rest_by_depth *
((scratch1 * scale).reshape(one_by_depth).broadcast(rest_by_one)))
.template cast<T>();
scale_backprop.device(d) = scratch2 * scratch1;
}
};
template struct FusedBatchNormFreezeGrad<GPUDevice, float, float>;
template struct FusedBatchNormFreezeGrad<GPUDevice, Eigen::half, float>;
template <class T>
__global__ void VarianceToInvVarianceKernel(int nthreads, const T* input,
double epsilon, T* output) {
GPU_1D_KERNEL_LOOP(index, nthreads) {
output[index] = rsqrt(input[index] + T(epsilon));
}
}
template <class T>
void VarianceToInvVariance<T>::operator()(const Eigen::GpuDevice& d,
const T* variance, double epsilon,
int channels, T* inv_variance) {
GpuLaunchConfig config = GetGpuLaunchConfig(channels, d);
TF_CHECK_OK(GpuLaunchKernel(VarianceToInvVarianceKernel<T>,
config.block_count, config.thread_per_block, 0,
d.stream(), config.virtual_thread_count, variance,
epsilon, inv_variance));
}
template <class T>
__global__ void InvVarianceToVarianceKernel(int nthreads, double epsilon,
int sample_size, T* variance) {
GPU_1D_KERNEL_LOOP(index, nthreads) {
T inv_var = variance[index];
T var = __fdividef(1, inv_var * inv_var) - T(epsilon);
// This is for Bessel's correction
var *= T(sample_size) / T((sample_size > 1) ? sample_size - 1 : 1);
variance[index] = (var > 0) ? var : 0;
}
}
template <class T>
void InvVarianceToVariance<T>::operator()(const Eigen::GpuDevice& d,
double epsilon, int sample_size,
int channels, T* variance) {
GpuLaunchConfig config = GetGpuLaunchConfig(channels, d);
TF_CHECK_OK(GpuLaunchKernel(InvVarianceToVarianceKernel<T>,
config.block_count, config.thread_per_block, 0,
d.stream(), config.virtual_thread_count, epsilon,
sample_size, variance));
}
template <class T>
void SetNanFunctor<T>::operator()(const Eigen::GpuDevice& d,
typename TTypes<T>::Flat out) {
To32Bit(out).device(d) =
To32Bit(out).constant(Eigen::NumTraits<T>::quiet_NaN());
}
template class VarianceToInvVariance<float>;
template class InvVarianceToVariance<float>;
template class SetNanFunctor<float>;
// -------------------------------------------------------------------------- //
// FusedBatchNormInferenceFunctor implementation. //
// -------------------------------------------------------------------------- //
// Generic kernel, that does all computations by converting input to U data
// type. We use it when CUDA architecture doesn't have fast arithmetic fot the
// T data type (e.g. no fp16 in old GPU generations).
template <typename T, typename U, TensorFormat tensor_format,
bool add_side_input, FusedBatchNormActivationMode activation_mode,
bool is_generic_kernel>
struct FusedBatchNormInferenceKernel {
static_assert(tensor_format == FORMAT_NHWC || tensor_format == FORMAT_NCHW,
"Unsupported data format");
__device__ static void run(int32 count, int32 channels_size,
int32 inner_dim_size, const T* in, const U* scale,
const U* offset, const U* mean, const U* var,
const T* side_input, float epsilon, T* out) {
int32 index = blockIdx.x * blockDim.x + threadIdx.x;
const int32 total_device_threads = gridDim.x * blockDim.x;
while (index < count) {
const int channel = (tensor_format == FORMAT_NHWC)
? index % channels_size
: (index / inner_dim_size) % channels_size;
U in_v = U(in[index]);
U scale_v = scale[channel];
U offset_v = offset[channel];
U mean_v = mean[channel];
U var_v = var[channel];
U scaling_factor_v = rsqrt(var_v + epsilon) * scale_v;
static_assert(std::is_same<U, float>::value, "U data type must be float");
U shifted_v = fmaf(in_v - mean_v, scaling_factor_v, offset_v);
if (add_side_input) {
shifted_v += U(side_input[index]);
}
if (activation_mode == FusedBatchNormActivationMode::kIdentity) {
out[index] = T(shifted_v);
} else if (activation_mode == FusedBatchNormActivationMode::kRelu) {
out[index] = T(shifted_v < U(0) ? U(0) : shifted_v);
}
index += total_device_threads;
}
}
};
// Specialization for T=Eigen::half and U=float.
template <TensorFormat tensor_format, bool add_side_input,
FusedBatchNormActivationMode activation_mode>
struct FusedBatchNormInferenceKernel<Eigen::half, float, tensor_format,
add_side_input, activation_mode,
/*is_generic_kernel=*/false> {
using T = Eigen::half;
using U = float;
// If CUDA architecture doesn't support fast fp16 computation, we will
// fallback on generic kernel defined above.
using GenericKernel =
FusedBatchNormInferenceKernel<T, U, tensor_format, add_side_input,
activation_mode,
/*is_generic_kernel=*/true>;
__device__ static void run(int32 count, int32 channels_size,
int32 inner_dim_size, const T* in, const U* scale,
const U* offset, const U* mean, const U* var,
const T* side_input, float epsilon, T* out) {
// Old GPUs do not have (or have very slow) fp16 arithmetic.
#if __CUDA_ARCH__ >= 610
int32 index = blockIdx.x * blockDim.x + threadIdx.x;
const int32 total_device_threads = gridDim.x * blockDim.x;
int32 half2_count = count >> 1;
half epsilon_h = __float2half(epsilon);
half2 epsilon_h2 = __float2half2_rn(epsilon);
const int32 max_channel_size = channels_size - 1;
while (index < half2_count) {
int32 channel[2];
if (tensor_format == FORMAT_NHWC) {
channel[0] = (2 * index) % channels_size;
channel[1] = channel[0] == max_channel_size ? 0 : channel[0] + 1;
} else {
channel[0] = ((2 * index) / inner_dim_size) % channels_size;
channel[1] = ((2 * index + 1) / inner_dim_size) % channels_size;
}
half2 in_v = reinterpret_cast<const half2*>(in)[index];
half2 scale_v = __floats2half2_rn(scale[channel[0]], scale[channel[1]]);
half2 offset_v =
__floats2half2_rn(offset[channel[0]], offset[channel[1]]);
half2 mean_v = __floats2half2_rn(mean[channel[0]], mean[channel[1]]);
half2 var_v = __floats2half2_rn(var[channel[0]], var[channel[1]]);
half2 scaling_factor_v =
__hmul2(h2rsqrt(__hadd2(var_v, epsilon_h2)), scale_v);
half2 shifted_v =
__hfma2(__hsub2(in_v, mean_v), scaling_factor_v, offset_v);
if (add_side_input) {
shifted_v = __hadd2(shifted_v,
reinterpret_cast<const half2*>(side_input)[index]);
}
if (activation_mode == FusedBatchNormActivationMode::kIdentity) {
reinterpret_cast<half2*>(out)[index] = shifted_v;
} else if (activation_mode == FusedBatchNormActivationMode::kRelu) {
const half2 kZeroH = __float2half2_rn(0.f);
const half2 mask_h = __hgt2(shifted_v, kZeroH);
reinterpret_cast<half2*>(out)[index] = __hmul2(mask_h, shifted_v);
}
index += total_device_threads;
}
if ((count & 0x1) == 1 && index == half2_count) {
index = count - 1;
const int32 channel = (tensor_format == FORMAT_NHWC)
? index % channels_size
: (index / inner_dim_size) % channels_size;
half in_v = in[index];
half scale_v = __float2half(scale[channel]);
half offset_v = __float2half(offset[channel]);
half mean_v = __float2half(mean[channel]);
half var_v = __float2half(var[channel]);
half scaling_factor_v = __hmul(hrsqrt(__hadd(var_v, epsilon_h)), scale_v);
half shifted_v = __hfma(__hsub(in_v, mean_v), scaling_factor_v, offset_v);
if (add_side_input) {
shifted_v = __hadd(shifted_v, side_input[index]);
}
if (activation_mode == FusedBatchNormActivationMode::kIdentity) {
out[index] = shifted_v;
} else if (activation_mode == FusedBatchNormActivationMode::kRelu) {
const half kZeroH = __float2half(0.f);
const half mask_h = __hgt(shifted_v, kZeroH);
out[index] = __hmul(mask_h, shifted_v);
}
}
#else
GenericKernel::run(count, channels_size, inner_dim_size, in, scale, offset,
mean, var, side_input, epsilon, out);
#endif // __CUDA_ARCH__ >= 610
}
};
template <typename T, typename U, TensorFormat tensor_format,
bool add_side_input, FusedBatchNormActivationMode activation_mode>
__global__ void FusedBatchNormInferenceMetaKernel(
int32 count, int32 channels_size, int32 inner_dim_size, const T* in,
const U* scale, const U* offset, const U* mean, const U* var,
const T* side_input, float epsilon, T* out) {
// We prefer to run non-generic specialization, for the given types T and U.
// TODO(b/135435976): Temporary disable non-generic kernel implementation.
FusedBatchNormInferenceKernel<
T, U, tensor_format, add_side_input, activation_mode,
/*is_generic_kernel=*/true>::run(count, channels_size, inner_dim_size, in,
scale, offset, mean, var, side_input,
epsilon, out);
}
template <typename T, typename U>
struct FusedBatchNormInferenceFunctor<GPUDevice, T, U> {
void operator()(OpKernelContext* context, TensorFormat tensor_format,
typename TTypes<T, 4>::ConstTensor in,
typename TTypes<U>::ConstVec scale,
typename TTypes<U>::ConstVec offset,
typename TTypes<U>::ConstVec estimated_mean,
typename TTypes<U>::ConstVec estimated_variance,
typename TTypes<T, 4>::ConstTensor side_input, U epsilon,
FusedBatchNormActivationMode activation_mode,
typename TTypes<T, 4>::Tensor out) {
const auto& d = context->eigen_device<GPUDevice>();
const int32 count = out.size();
if (count == 0) return;
bool launched = false;
constexpr int32 kThreadInBlock = 512;
#define LAUNCH(DATA_FORMAT, ADD_SIDE_INPUT, ACTIVATION, CHANNEL_SIZE, \
INNER_DIM_SIZE) \
launched = true; \
\
GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize( \
std::is_same<T, Eigen::half>::value ? Eigen::divup(count, 2) : count, d, \
FusedBatchNormInferenceMetaKernel<T, U, DATA_FORMAT, ADD_SIDE_INPUT, \
ACTIVATION>, \
0, kThreadInBlock); \
\
TF_CHECK_OK(GpuLaunchKernel( \
FusedBatchNormInferenceMetaKernel<T, U, DATA_FORMAT, ADD_SIDE_INPUT, \
ACTIVATION>, \
config.block_count, config.thread_per_block, 0, d.stream(), count, \
CHANNEL_SIZE, INNER_DIM_SIZE, in.data(), scale.data(), offset.data(), \
estimated_mean.data(), estimated_variance.data(), side_input.data(), \
epsilon, out.data()));
const bool no_side_input = side_input.dimensions().TotalSize() == 0;
const bool add_side_input = side_input.dimensions().TotalSize() != 0;
using Activation = FusedBatchNormActivationMode;
const bool no_activation = activation_mode == Activation::kIdentity;
const bool relu_activation = activation_mode == Activation::kRelu;
if (tensor_format == FORMAT_NHWC) {
const int c = in.dimensions()[3];
if (no_activation && no_side_input) {
LAUNCH(FORMAT_NHWC, false, Activation::kIdentity, c, 1);
} else if (relu_activation && no_side_input) {
LAUNCH(FORMAT_NHWC, false, Activation::kRelu, c, 1);
} else if (no_activation && add_side_input) {
LAUNCH(FORMAT_NHWC, true, Activation::kIdentity, c, 1);
} else if (relu_activation && add_side_input) {
LAUNCH(FORMAT_NHWC, true, Activation::kRelu, c, 1);
}
} else if (tensor_format == FORMAT_NCHW) {
const int c = in.dimensions()[1];
const int inner = in.dimensions()[2] * in.dimensions()[3];
if (no_activation && no_side_input) {
LAUNCH(FORMAT_NCHW, false, Activation::kIdentity, c, inner);
} else if (relu_activation && no_side_input) {
LAUNCH(FORMAT_NCHW, false, Activation::kRelu, c, inner);
} else if (no_activation && add_side_input) {
LAUNCH(FORMAT_NCHW, true, Activation::kIdentity, c, inner);
} else if (relu_activation && add_side_input) {
LAUNCH(FORMAT_NCHW, true, Activation::kRelu, c, inner);
}
}
#undef LAUNCH
OP_REQUIRES(context, launched,
errors::InvalidArgument("Unsupported launch configuration"));
}
};
template struct FusedBatchNormInferenceFunctor<GPUDevice, float, float>;
template struct FusedBatchNormInferenceFunctor<GPUDevice, Eigen::half, float>;
} // namespace functor
} // namespace tensorflow
#else
#include "tensorflow/core/kernels/fused_batch_norm_op.h"
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM