#pragma once

#include <memory>
#include <vector>

#include <THPP/THPP.h>

#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/variable.h"

#ifdef WITH_CUDNN
#include "torch/csrc/cudnn/Conv.h"
#else
// Stub definition so that the std::unique_ptr<torch::cudnn::Convolution>
// member below still compiles when building without cuDNN support.
namespace torch { namespace cudnn {
struct Convolution {};
}}
#endif
namespace torch { namespace autograd {

// Hyperparameters shared by the forward and backward convolution functions.
struct ConvParams {
  std::vector<int> stride;
  std::vector<int> padding;
  std::vector<int> dilation;
  bool transposed;
  std::vector<int> output_padding;
  int groups;
  bool benchmark;
  bool cudnn_enabled;

  bool is_dilated() const;
  bool is_output_padding_neg() const;
  bool is_padding_neg() const;
  // Reshapes 1D convolution parameters so they can be handled by the 2D code paths.
  void view1d_as_2d();
  // Whether the cuDNN implementation can be used for the given input.
  bool use_cudnn(const thpp::Tensor& input) const;
};
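
// A minimal usage sketch (not part of this header; the values are purely
// illustrative) of how a caller might fill in ConvParams for an ordinary
// 3x3, stride-1, same-padded 2D convolution:
//
//   ConvParams params;
//   params.stride = {1, 1};
//   params.padding = {1, 1};
//   params.dilation = {1, 1};
//   params.transposed = false;
//   params.output_padding = {0, 0};
//   params.groups = 1;
//   params.benchmark = false;     // don't run the cuDNN autotuner
//   params.cudnn_enabled = true;  // let use_cudnn() opt in when possible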
struct ConvForward : public Function, public ConvParams {
  explicit ConvForward(ConvParams params) : ConvParams(std::move(params)) {}

  virtual variable_list apply(const variable_list& inputs) override;

  // Computes the spatial size of the convolution output for the given input and weight.
  std::vector<long> output_size(thpp::Tensor& input, thpp::Tensor& weight);
};
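
// For reference, output_size presumably follows the standard convolution
// arithmetic (an assumption; the definition lives in the .cpp file). For a
// non-transposed convolution, each spatial dimension is:
//
//   out = (in + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
//
// e.g. in = 32, kernel = 3, padding = 1, dilation = 1, stride = 1 gives
// out = (32 + 2 - 2 - 1) / 1 + 1 = 32.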
struct ConvBackward : public Function, public ConvParams {
ConvBackward(
FunctionFlags flags,
ConvParams params,
SavedVariable input,
SavedVariable weight,
SavedVariable bias,
tensor_list columns,
tensor_list ones,
std::unique_ptr<torch::cudnn::Convolution> convolution)
: Function(std::move(flags))
, ConvParams(std::move(params))
, convolution(std::move(convolution)) {
if (is_executable) {
this->input_ = std::move(input);
this->weight_ = std::move(weight);
this->bias_ = std::move(bias);
this->columns = std::move(columns);
this->ones = std::move(ones);
}
}
virtual variable_list apply(const variable_list& gradOutputs) override;
virtual void releaseVariables() override;
SavedVariable input_;
SavedVariable weight_;
SavedVariable bias_;
tensor_list columns;
tensor_list ones;
std::unique_ptr<torch::cudnn::Convolution> convolution;
};
}} // namespace torch::autograd
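
// A hedged end-to-end sketch (assumptions: input_var, weight_var and
// bias_var are existing Variables, and params is a populated ConvParams;
// none of these are defined in this header):
//
//   auto conv = std::make_shared<ConvForward>(params);
//   variable_list outputs = conv->apply({input_var, weight_var, bias_var});
//
// apply() is expected to record a ConvBackward node on its outputs so that
// gradients for input, weight and bias can be computed later.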