blob: 2ca57f000dd84519ba717dc62e1ca3feaa17f051 [file] [log] [blame]
#include "caffe2/core/context_gpu.h"
#include "caffe2/core/cudnn_wrappers.h"
#include "caffe2/core/types.h"
#include "caffe2/operators/softmax_op.h"
namespace caffe2 {
namespace {
// Descriptor bookkeeping constants.
// NOTE(review): none of these are referenced anywhere in this file; they
// look like leftovers from an earlier implementation that stored the
// tensor descriptors in an indexed array. Candidates for removal once
// confirmed nothing else (e.g. a sibling .h) uses them.
constexpr int NUM_DESCRIPTORS = 2;
constexpr int GRADIENT_NUM_DESCRIPTORS = 3;
constexpr int BOTTOM_DESC_ID = 0;
constexpr int TOP_DESC_ID = 1;
constexpr int TOP_GRADIENT_DESC_ID = 2;
} // namespace
// cuDNN-backed Softmax. Flattens the input to a 2D (N x D) view around
// `axis` (outer dims -> N, `axis` and inner dims -> D) and runs
// cudnnSoftmaxForward in ACCURATE/INSTANCE mode, i.e. softmax is taken
// over all D elements of each of the N rows.
class CuDNNSoftmaxOp final : public Operator<CUDAContext> {
 public:
  template <class... Args>
  explicit CuDNNSoftmaxOp(Args&&... args)
      : Operator<CUDAContext>(std::forward<Args>(args)...),
        cudnn_wrapper_(&context_),
        axis_(OperatorBase::GetSingleArgument<int>("axis", 1)) {
    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&desc_));
  }

  ~CuDNNSoftmaxOp() override {
    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(desc_));
  }

  template <typename T>
  bool DoRunWithType() {
    auto& X = Input(0);
    const auto canonical_axis = X.canonical_axis_index(axis_);
    const int N = X.size_to_dim(canonical_axis);
    const int D = X.size_from_dim(canonical_axis);
    auto* Y = Output(0, X.sizes(), at::dtype<T>());
    auto* Y_data = Y->template mutable_data<T>();
    // Empty tensors: output is already allocated (size 0); nothing to run,
    // and cuDNN rejects zero-extent descriptors.
    if (N == 0 || D == 0) {
      return true;
    }
    // Refresh the cached descriptor when either the shape OR the element
    // type changes. Checking dims alone is not enough: RunOnDevice
    // re-dispatches on the input's dtype on every call, so the same op
    // instance can see float and then at::Half inputs with identical
    // sizes, and the descriptor would otherwise keep the stale dtype.
    if (dims_ != X.sizes() || cudnnTypeWrapper<T>::type != cached_dtype_) {
      CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
          desc_,
          GetCudnnTensorFormat(StorageOrder::NCHW),
          cudnnTypeWrapper<T>::type,
          N,
          D,
          1,
          1));
      dims_ = X.sizes().vec();
      cached_dtype_ = cudnnTypeWrapper<T>::type;
    }
    CUDNN_ENFORCE(cudnnSoftmaxForward(
        cudnn_wrapper_.inline_cudnn_handle(),
        CUDNN_SOFTMAX_ACCURATE,
        CUDNN_SOFTMAX_MODE_INSTANCE,
        cudnnTypeWrapper<T>::kOne(),
        desc_,
        X.template data<T>(),
        cudnnTypeWrapper<T>::kZero(),
        desc_,
        Y_data));
    return true;
  }

  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0));
  }

 protected:
  CuDNNWrapper cudnn_wrapper_;
  int axis_;
  cudnnTensorDescriptor_t desc_;
  vector<int64_t> dims_;
  // Element type the descriptor was last configured with; initial value is
  // irrelevant because the first run always has empty dims_.
  cudnnDataType_t cached_dtype_ = cudnnTypeWrapper<float>::type;
};
// Gradient of the cuDNN-backed Softmax. Inputs: Y (forward output) and
// dY (gradient w.r.t. Y); output: dX. Uses the same (N x D) flattening
// around `axis` as the forward op and calls cudnnSoftmaxBackward in
// ACCURATE/INSTANCE mode.
class CuDNNSoftmaxGradientOp final : public Operator<CUDAContext> {
 public:
  template <class... Args>
  explicit CuDNNSoftmaxGradientOp(Args&&... args)
      : Operator<CUDAContext>(std::forward<Args>(args)...),
        cudnn_wrapper_(&context_),
        axis_(OperatorBase::GetSingleArgument<int>("axis", 1)) {
    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&desc_));
  }

  ~CuDNNSoftmaxGradientOp() override {
    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(desc_));
  }

  template <typename T>
  bool DoRunWithType() {
    auto& Y = Input(0);
    auto& dY = Input(1);
    const auto canonical_axis = Y.canonical_axis_index(axis_);
    const int N = Y.size_to_dim(canonical_axis);
    const int D = Y.size_from_dim(canonical_axis);
    // Y and dY must match elementwise; a single shared descriptor below
    // relies on this.
    CHECK_EQ(Y.sizes(), dY.sizes());
    auto* dX = Output(0, Y.sizes(), at::dtype<T>());
    auto* dX_data = dX->template mutable_data<T>();
    // Empty tensors: output is already allocated (size 0); nothing to run,
    // and cuDNN rejects zero-extent descriptors.
    if (N == 0 || D == 0) {
      return true;
    }
    // Refresh the cached descriptor when either the shape OR the element
    // type changes. Checking dims alone is not enough: RunOnDevice
    // re-dispatches on the input's dtype on every call, so the same op
    // instance can see float and then at::Half inputs with identical
    // sizes, and the descriptor would otherwise keep the stale dtype.
    if (dims_ != Y.sizes() || cudnnTypeWrapper<T>::type != cached_dtype_) {
      CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
          desc_,
          GetCudnnTensorFormat(StorageOrder::NCHW),
          cudnnTypeWrapper<T>::type,
          N,
          D,
          1,
          1));
      dims_ = Y.sizes().vec();
      cached_dtype_ = cudnnTypeWrapper<T>::type;
    }
    CUDNN_ENFORCE(cudnnSoftmaxBackward(
        cudnn_wrapper_.inline_cudnn_handle(),
        CUDNN_SOFTMAX_ACCURATE,
        CUDNN_SOFTMAX_MODE_INSTANCE,
        cudnnTypeWrapper<T>::kOne(),
        desc_,
        Y.template data<T>(),
        desc_,
        dY.template data<T>(),
        cudnnTypeWrapper<T>::kZero(),
        desc_,
        dX_data));
    return true;
  }

  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0));
  }

 protected:
  CuDNNWrapper cudnn_wrapper_;
  int axis_;
  cudnnTensorDescriptor_t desc_;
  vector<int64_t> dims_;
  // Element type the descriptor was last configured with; initial value is
  // irrelevant because the first run always has empty dims_.
  cudnnDataType_t cached_dtype_ = cudnnTypeWrapper<float>::type;
};
namespace {
// Register these implementations under the CUDNN engine for the existing
// Softmax / SoftmaxGradient operator schemas.
REGISTER_CUDNN_OPERATOR(Softmax, CuDNNSoftmaxOp);
REGISTER_CUDNN_OPERATOR(SoftmaxGradient, CuDNNSoftmaxGradientOp);
} // namespace
} // namespace caffe2