blob: 532de9b769e6b1b5b80d67abcf52898bd254de9b [file] [log] [blame]
#include <sstream>
#include <stdexcept>
#include <tuple>

#include "ATen/ATen.h"
#include "ATen/NativeFunctions.h"
#include "ATen/Config.h"
#if AT_CUDNN_ENABLED()
#include "THC/THC.h"
#include "ATen/cudnn/cudnn-wrapper.h"
#endif
namespace at { namespace native {
namespace {

// Verifies that a per-feature argument has the expected number of elements.
// Returns normally when `actual == expected`; otherwise throws
// std::runtime_error with a message naming `arg_name` and both counts.
void check_dims_match_num_input_features(const char* arg_name, int64_t expected, int64_t actual) {
  if (actual == expected) {
    return;
  }
  std::stringstream msg;
  msg << arg_name << " should contain " << expected << " elements not " << actual;
  throw std::runtime_error(msg.str());
}

}  // namespace
/// Batch normalization entry point that validates arguments and dispatches to
/// a backend.
///
/// Dimension 1 of `input` is treated as the feature dimension. Every defined
/// per-feature tensor must hold exactly `input.sizes()[1]` elements. The call is
/// routed to at::cudnn_batch_norm when cuDNN support is compiled in and all of
/// the conditions below hold; otherwise it goes to at::thnn_batch_norm.
///
/// @param input          Tensor with at least 2 dimensions.
/// @param weight         Optional per-feature tensor (may be undefined).
/// @param bias           Optional per-feature tensor (may be undefined).
/// @param running_mean   Optional per-feature tensor; must be defined when !training.
/// @param running_var    Optional per-feature tensor; must be defined when !training.
/// @param training       Forwarded unchanged to the selected backend.
/// @param momentum       Forwarded unchanged to the selected backend.
/// @param eps            Forwarded unchanged; the cuDNN path also requires
///                       eps >= CUDNN_BN_MIN_EPSILON.
/// @param cudnn_enabled  Caller opt-in; the cuDNN path is never taken when false.
/// @return Element 0 of cudnn_batch_norm's result tuple, or the result of
///         thnn_batch_norm.
/// @throws std::runtime_error if input has fewer than 2 dimensions, if a defined
///         per-feature tensor has the wrong element count, or if running
///         statistics are missing in evaluation mode.
Tensor batch_norm(
    const Tensor& input, const Tensor& weight /* optional */, const Tensor& bias /* optional */,
    const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */,
    bool training, double momentum, double eps, bool cudnn_enabled) {
  // sizes()[1] below is an unchecked index, so reject inputs without a
  // dimension 1 up front instead of reading out of bounds.
  if (input.dim() < 2) {
    std::stringstream ss;
    ss << "batch_norm: expected input with at least 2 dimensions, got " << input.dim();
    throw std::runtime_error(ss.str());
  }
  auto num_features = input.sizes()[1];

  if (running_mean.defined()) {
    check_dims_match_num_input_features("running_mean", num_features, running_mean.numel());
  } else if (!training) {
    throw std::runtime_error("running_mean must be defined in evaluation mode");
  }
  if (running_var.defined()) {
    check_dims_match_num_input_features("running_var", num_features, running_var.numel());
  } else if (!training) {
    throw std::runtime_error("running_var must be defined in evaluation mode");
  }
  if (weight.defined()) {
    check_dims_match_num_input_features("weight", num_features, weight.numel());
  }
  if (bias.defined()) {
    check_dims_match_num_input_features("bias", num_features, bias.numel());
  }

#if AT_CUDNN_ENABLED()
  const bool use_cudnn = input.type().is_cuda()
      // Checked before the dtype test so weight.type() is never queried on an
      // undefined weight tensor.
      && weight.defined() && bias.defined()
      // For half input, the cuDNN path is only taken when weight is float.
      && (input.type().scalarType() != at::kHalf
          || weight.type().scalarType() == at::kFloat)
      // Running stats must be both present, or both absent while training.
      && ((running_mean.defined() && running_var.defined())
          || (!running_mean.defined() && !running_var.defined() && training))
      // NOTE(review): batch-size cap for the cuDNN path, presumably a cuDNN
      // limitation - confirm against cuDNN release notes.
      && input.size(0) <= 131070
      && cudnn_enabled && CUDNN_VERSION >= 5110L
      && eps >= CUDNN_BN_MIN_EPSILON;
  if (use_cudnn) {
    return std::get<0>(at::cudnn_batch_norm(
        input, weight, bias,
        running_mean, running_var,
        training, momentum, eps));
  }
#endif

  return at::thnn_batch_norm(
      input, weight, bias,
      running_mean, running_var, training, momentum, eps);
}
}} // at::native