// blob: 696215cb6d4cdb219b6152bb7dbdce724ca7d5be [file] [log] [blame]
#ifndef THP_CUDNN_DESCRIPTORS_INC
#define THP_CUDNN_DESCRIPTORS_INC
#include "Exceptions.h"
#include <cudnn.h>
namespace torch { namespace cudnn {
// Move-only owner of a cudnnTensorDescriptor_t handle.
//
// The handle is created in the default constructor and destroyed in the
// destructor. Ownership can be transferred with the move constructor;
// copying is disabled so two objects never destroy the same handle.
struct TensorDescriptor
{
  // Owned cuDNN handle; nullptr once this object has been moved from.
  cudnnTensorDescriptor_t desc;

  // Creates a fresh descriptor. The status is passed to CHECK (declared in
  // Exceptions.h; presumably raises on a non-success status -- confirm there).
  TensorDescriptor() : desc(nullptr) {
    CHECK(cudnnCreateTensorDescriptor(&desc));
  }

  TensorDescriptor(const TensorDescriptor&) = delete;
  // Already implicitly deleted because a move constructor is declared;
  // spelled out so the move-only intent is visible.
  TensorDescriptor& operator=(const TensorDescriptor&) = delete;

  // Takes over ref's handle and leaves ref empty. Only copies a pointer, so
  // it cannot throw.
  TensorDescriptor(TensorDescriptor&& ref) noexcept
    : desc(ref.desc)
  {
    ref.desc = nullptr;
  }

  // Releases the handle. A moved-from object holds nullptr and is skipped so
  // no null handle is handed to cuDNN. The returned status is not inspected.
  ~TensorDescriptor() {
    if (desc != nullptr) {
      cudnnDestroyTensorDescriptor(desc);
    }
  }

  // Configures the descriptor through cudnnSetTensorNdDescriptor.
  //   dataType - element type forwarded to cuDNN
  //   dim      - number of dimensions
  //   size     - per-dimension extents (presumably `dim` entries -- confirm)
  //   stride   - per-dimension strides (presumably `dim` entries -- confirm)
  // The resulting status is passed to CHECK.
  void set(cudnnDataType_t dataType, int dim, int* size, int* stride) {
    CHECK(cudnnSetTensorNdDescriptor(desc, dataType, dim, size, stride));
  }
};
// Move-only owner of a cudnnFilterDescriptor_t handle.
//
// The handle is created in the default constructor and destroyed in the
// destructor. Ownership can be transferred with the move constructor;
// copying is disabled so two objects never destroy the same handle.
struct FilterDescriptor
{
  // Owned cuDNN handle; nullptr once this object has been moved from.
  cudnnFilterDescriptor_t desc;

  // Creates a fresh descriptor. The status is passed to CHECK (declared in
  // Exceptions.h; presumably raises on a non-success status -- confirm there).
  FilterDescriptor() : desc(nullptr) {
    CHECK(cudnnCreateFilterDescriptor(&desc));
  }

  FilterDescriptor(const FilterDescriptor&) = delete;
  // Already implicitly deleted because a move constructor is declared;
  // spelled out so the move-only intent is visible.
  FilterDescriptor& operator=(const FilterDescriptor&) = delete;

  // Takes over ref's handle and leaves ref empty. Only copies a pointer, so
  // it cannot throw.
  FilterDescriptor(FilterDescriptor&& ref) noexcept
    : desc(ref.desc)
  {
    ref.desc = nullptr;
  }

  // Releases the handle. A moved-from object holds nullptr and is skipped so
  // no null handle is handed to cuDNN. The returned status is not inspected.
  ~FilterDescriptor() {
    if (desc != nullptr) {
      cudnnDestroyFilterDescriptor(desc);
    }
  }

  // Configures the descriptor through cudnnSetFilterNdDescriptor, always
  // using the CUDNN_TENSOR_NCHW layout.
  //   dataType - element type forwarded to cuDNN
  //   dim      - number of dimensions
  //   size     - per-dimension extents (presumably `dim` entries -- confirm)
  // The resulting status is passed to CHECK.
  void set(cudnnDataType_t dataType, int dim, int* size) {
    CHECK(cudnnSetFilterNdDescriptor(desc, dataType, CUDNN_TENSOR_NCHW, dim, size));
  }
};
// Move-only owner of a cudnnConvolutionDescriptor_t handle.
//
// The handle is created in the default constructor and destroyed in the
// destructor. Ownership can be transferred with the move constructor;
// copying is disabled so two objects never destroy the same handle.
struct ConvolutionDescriptor
{
  // Owned cuDNN handle; nullptr once this object has been moved from.
  cudnnConvolutionDescriptor_t desc;

  // Creates a fresh descriptor. The status is passed to CHECK (declared in
  // Exceptions.h; presumably raises on a non-success status -- confirm there).
  ConvolutionDescriptor() : desc(nullptr) {
    CHECK(cudnnCreateConvolutionDescriptor(&desc));
  }

  ConvolutionDescriptor(const ConvolutionDescriptor&) = delete;
  // Already implicitly deleted because a move constructor is declared;
  // spelled out so the move-only intent is visible.
  ConvolutionDescriptor& operator=(const ConvolutionDescriptor&) = delete;

  // Takes over ref's handle and leaves ref empty. Only copies a pointer, so
  // it cannot throw.
  ConvolutionDescriptor(ConvolutionDescriptor&& ref) noexcept
    : desc(ref.desc)
  {
    ref.desc = nullptr;
  }

  // Releases the handle. A moved-from object holds nullptr and is skipped so
  // no null handle is handed to cuDNN. The returned status is not inspected.
  ~ConvolutionDescriptor() {
    if (desc != nullptr) {
      cudnnDestroyConvolutionDescriptor(desc);
    }
  }

  // Configures the descriptor through cudnnSetConvolutionNdDescriptor in
  // CUDNN_CROSS_CORRELATION mode.
  //   dataType - element type of the tensors involved
  //   dim      - length of the pad/stride/upscale arrays
  //   pad      - per-dimension padding
  //   stride   - per-dimension filter stride
  //   upscale  - per-dimension upscale factors, forwarded unchanged
  // The resulting status is passed to CHECK.
  void set(cudnnDataType_t dataType, int dim, int* pad, int* stride, int * upscale) {
    // Half-precision data is described to cuDNN with a float math type;
    // every other data type is passed through unchanged.
    cudnnDataType_t mathType = dataType;
    if (dataType == CUDNN_DATA_HALF) mathType = CUDNN_DATA_FLOAT;
    CHECK(cudnnSetConvolutionNdDescriptor(desc, dim, pad, stride, upscale,
                                          CUDNN_CROSS_CORRELATION, mathType));
  }
};
// A scalar stored either as float or as double, chosen by cuDNN data type.
//
// For CUDNN_DATA_HALF and CUDNN_DATA_FLOAT the value is narrowed to float
// and written to `f`; for every other data type it is written to `d`.
// Only the member selected at construction holds a value.
// NOTE(review): presumably used for the scaling factors passed by pointer
// to cuDNN calls -- confirm against callers.
union Constant
{
  float f;
  double d;

  Constant(cudnnDataType_t dataType, double value) {
    const bool storeAsFloat =
        (dataType == CUDNN_DATA_FLOAT) || (dataType == CUDNN_DATA_HALF);
    if (!storeAsFloat) {
      d = value;
      return;
    }
    f = static_cast<float>(value);
  }
};
// Maps a cuDNN data type to a size: 2 for CUDNN_DATA_HALF, 4 for
// CUDNN_DATA_FLOAT, and 8 for any other value.
// NOTE(review): presumably the element size in bytes, with the fallback
// intended for CUDNN_DATA_DOUBLE -- confirm against callers.
inline int dataSize(cudnnDataType_t dataType)
{
  if (dataType == CUDNN_DATA_HALF) {
    return 2;
  }
  if (dataType == CUDNN_DATA_FLOAT) {
    return 4;
  }
  // Everything else shares the 8 fallback.
  return 8;
}
}} // namespace
#endif