// (stray web-viewer artifact, not source code — kept as a comment so the header compiles)
// blob: a328a62d5f772dc17a2c6e17afcd33f103c04655 [file] [log] [blame]
#pragma once
#include <Python.h>
#include <ATen/ATen.h>
#include <memory>
#include <vector>
#include <iostream>
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/autograd/symbolic.h"
#ifdef WITH_CUDNN
#include "torch/csrc/cudnn/Conv.h"
#else
// Stub definition so the std::unique_ptr<torch::cudnn::Convolution> member
// below still compiles in builds without cuDNN support.
namespace torch { namespace cudnn {
struct Convolution {};
}}
#endif
namespace torch { namespace autograd {
// Hyper-parameters shared by the convolution autograd functions below;
// ConvForward, ConvBackward and ConvBackwardBackward all inherit from this.
// The vector-valued fields presumably hold one entry per spatial dimension —
// TODO(review): confirm against the .cpp definitions.
struct ConvParams {
std::vector<int> stride;
std::vector<int> padding;
std::vector<int> dilation;
// Transposed ("deconvolution") mode; output_padding below is the extra
// size added to the transposed output — assumption from the names, verify.
bool transposed;
std::vector<int> output_padding;
int groups;
// benchmark: cuDNN autotuner flag; cudnn_enabled: master switch for the
// cuDNN path (consulted by use_cudnn, see the .cpp).
bool benchmark;
bool cudnn_enabled;
// Predicates over the parameter vectors; all defined out-of-line.
bool is_dilated() const;
bool is_output_padding_neg() const;
bool is_output_padding_big() const;
bool is_padding_neg() const;
// Rewrites the params in place; name suggests expanding 1-d conv
// parameters to their 2-d equivalents — TODO confirm in the .cpp.
void view1d_as_2d();
// Whether the cuDNN implementation should be used for this input.
bool use_cudnn(const at::Tensor& input) const;
};
// Forward autograd function for convolution. Inherits the hyper-parameters
// from ConvParams; HasSymbolic provides the symbolic() hook used to emit a
// JIT graph node for this op.
struct ConvForward : public ForwardFunction<>, public ConvParams, public HasSymbolic {
explicit ConvForward(ConvParams params) : ConvParams(std::move(params)) {}
virtual std::string name() override;
virtual variable_list apply(const variable_list& inputs) override;
virtual jit::node_list symbolic(SymbolicContext* ctx, jit::node_list inputs) override;
// Shape of the convolution result for the given input/weight tensors;
// defined in the .cpp.
std::vector<int64_t> output_size(at::Tensor& input, at::Tensor& weight);
};
struct ConvBackward : public Function, public ConvParams {
ConvBackward(
FunctionFlags flags,
ConvParams params,
const std::shared_ptr<Variable>& input,
const std::shared_ptr<Variable>& weight,
const std::shared_ptr<Variable>& bias,
tensor_list columns,
tensor_list ones,
std::unique_ptr<torch::cudnn::Convolution> convolution)
: Function(std::move(flags))
, ConvParams(std::move(params))
, convolution(std::move(convolution)) {
if (is_executable) {
this->input_ = input->save(this);
this->weight_ = weight->save(this);
this->bias_ = Variable::save_opt(bias.get(), this);
this->columns = std::move(columns);
this->ones = std::move(ones);
}
}
virtual variable_list apply(const variable_list& gradOutputs) override;
virtual void releaseVariables() override;
SavedVariable input_;
SavedVariable weight_;
SavedVariable bias_;
tensor_list columns;
tensor_list ones;
std::unique_ptr<torch::cudnn::Convolution> convolution;
};
struct ConvBackwardBackward : public Function, public ConvParams {
ConvBackwardBackward(
FunctionFlags flags,
ConvParams params,
const std::shared_ptr<Variable>& input,
const std::shared_ptr<Variable>& weight,
const std::shared_ptr<Variable>& bias,
const std::shared_ptr<Variable>& grad_output)
: Function(std::move(flags))
, ConvParams(std::move(params)) {
if (is_executable) {
this->input_ = input->save(this);
this->weight_ = weight->save(this);
this->bias_ = Variable::save_opt(bias.get(), this);
this->grad_output_ = grad_output->save(this);
}
}
virtual variable_list apply(const variable_list& grad_grad_inputs) override;
virtual void releaseVariables() override;
SavedVariable input_;
SavedVariable weight_;
SavedVariable bias_;
SavedVariable grad_output_;
};
}} // namespace torch::autograd