/*
 * CuDNN ReLU extension. A simple function, but it shows the general structure
 * of most CuDNN extensions:
 * 1) Check arguments. The torch::check* functions provide a standard way to
 *    validate input and produce readable error messages.
 * 2) Create descriptors. Most CuDNN functions require creating and setting a
 *    variety of descriptors.
 * 3) Apply the CuDNN function.
 * 4) Destroy your descriptors.
 * 5) Return something (optional).
 */
|  |  | 
|  | #include <torch/extension.h> | 
|  |  | 
|  | #include <ATen/cuda/Exceptions.h> // for CUDNN_CHECK | 
|  | #include <ATen/cudnn/Descriptors.h> // for TensorDescriptor | 
|  | #include <ATen/cudnn/Handle.h> // for getCudnnHandle | 
|  |  | 
// Single source of truth for the function's name: it is the name the function
// is registered under in the python module and the name the torch::check*
// helpers are given for their error messages.
const char* cudnn_relu_name{"cudnn_relu"};
|  |  | 
|  | // Check arguments to cudnn_relu | 
|  | void cudnn_relu_check( | 
|  | const torch::Tensor& inputs, | 
|  | const torch::Tensor& outputs) { | 
|  | // Create TensorArgs. These record the names and positions of each tensor as a | 
|  | // parameter. | 
|  | torch::TensorArg arg_inputs(inputs, "inputs", 0); | 
|  | torch::TensorArg arg_outputs(outputs, "outputs", 1); | 
|  | // Check arguments. No need to return anything. These functions with throw an | 
|  | // error if they fail. Messages are populated using information from | 
|  | // TensorArgs. | 
|  | torch::checkContiguous(cudnn_relu_name, arg_inputs); | 
|  | torch::checkScalarType(cudnn_relu_name, arg_inputs, torch::kFloat); | 
|  | torch::checkBackend(cudnn_relu_name, arg_inputs.tensor, torch::Backend::CUDA); | 
|  | torch::checkContiguous(cudnn_relu_name, arg_outputs); | 
|  | torch::checkScalarType(cudnn_relu_name, arg_outputs, torch::kFloat); | 
|  | torch::checkBackend( | 
|  | cudnn_relu_name, arg_outputs.tensor, torch::Backend::CUDA); | 
|  | torch::checkSameSize(cudnn_relu_name, arg_inputs, arg_outputs); | 
|  | } | 
|  |  | 
|  | void cudnn_relu(const torch::Tensor& inputs, const torch::Tensor& outputs) { | 
|  | // Most CuDNN extensions will follow a similar pattern. | 
|  | // Step 1: Check inputs. This will throw an error if inputs are invalid, so no | 
|  | // need to check return codes here. | 
|  | cudnn_relu_check(inputs, outputs); | 
|  | // Step 2: Create descriptors | 
|  | cudnnHandle_t cuDnn = torch::native::getCudnnHandle(); | 
|  | // Note: 4 is minimum dim for a TensorDescriptor. Input and output are same | 
|  | // size and type and contiguous, so one descriptor is sufficient. | 
|  | torch::native::TensorDescriptor input_tensor_desc(inputs, 4); | 
|  | cudnnActivationDescriptor_t activationDesc; | 
|  | // Note: Always check return value of cudnn functions using CUDNN_CHECK | 
|  | AT_CUDNN_CHECK(cudnnCreateActivationDescriptor(&activationDesc)); | 
|  | AT_CUDNN_CHECK(cudnnSetActivationDescriptor( | 
|  | activationDesc, | 
|  | /*mode=*/CUDNN_ACTIVATION_RELU, | 
|  | /*reluNanOpt=*/CUDNN_PROPAGATE_NAN, | 
|  | /*coef=*/1.)); | 
|  | // Step 3: Apply CuDNN function | 
|  | float alpha = 1.; | 
|  | float beta = 0.; | 
|  | AT_CUDNN_CHECK(cudnnActivationForward( | 
|  | cuDnn, | 
|  | activationDesc, | 
|  | &alpha, | 
|  | input_tensor_desc.desc(), | 
|  | inputs.data_ptr(), | 
|  | &beta, | 
|  | input_tensor_desc.desc(), // output descriptor same as input | 
|  | outputs.data_ptr())); | 
|  | // Step 4: Destroy descriptors | 
|  | AT_CUDNN_CHECK(cudnnDestroyActivationDescriptor(activationDesc)); | 
|  | // Step 5: Return something (optional) | 
|  | } | 
|  |  | 
|  | // Create the pybind11 module | 
|  | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | 
|  | // Use the same name as the check functions so error messages make sense | 
|  | m.def(cudnn_relu_name, &cudnn_relu, "CuDNN ReLU"); | 
|  | } |