blob: 970b4edca39c417f57e1fd30786ff7d4e0893c14 [file] [log] [blame]
#pragma once
#include <ATen/Tensor.h>
#include <ATen/native/vulkan/VulkanCommon.h>
#include <torch/custom_class.h>
namespace at {
namespace native {
namespace vulkan {
// Tuple type returned by Conv2dOpContext::unpack(): the original
// (un-prepacked) arguments a conv2d op context was created from, in this
// fixed element order (matching the std::make_tuple call in unpack()):
//   0: weight, 1: bias, 2: stride, 3: padding, 4: dilation, 5: groups,
//   6: output_min, 7: output_max.
// NOTE(review): the name suggests this is the serialized (pickled) state of a
// prepacked conv2d context; if so, changing or reordering these elements
// breaks compatibility with previously saved models — confirm against the
// class registration site before editing.
using SerializationTypeConv2dPrePack = std::tuple<
Tensor, // 0: weight (orig_weight_)
c10::optional<Tensor>, // 1: bias (orig_bias_)
std::vector<int64_t>, // 2: stride (stride_)
std::vector<int64_t>, // 3: padding (padding_)
std::vector<int64_t>, // 4: dilation (dilation_)
int64_t, // 5: groups (groups_)
c10::optional<Scalar>, // 6: output_min (output_min_); presumably a clamp lower bound — confirm
c10::optional<Scalar>>; // 7: output_max (output_max_); presumably a clamp upper bound — confirm
// Abstract base for a prepacked 2D-convolution operator context. It derives
// from torch::jit::CustomClassHolder, so instances can be held by
// c10::intrusive_ptr and exposed as a TorchScript custom class.
//
// The base keeps the original, un-prepacked convolution arguments so they can
// be handed back by unpack(). Backend-specific prepacked state and the actual
// computation belong to subclasses, which must populate the stored arguments
// and implement run().
class Conv2dOpContext : public torch::jit::CustomClassHolder {
 protected:
  // Original arguments the context was built from; returned by unpack().
  Tensor orig_weight_;
  c10::optional<Tensor> orig_bias_;
  std::vector<int64_t> stride_;
  std::vector<int64_t> padding_;
  std::vector<int64_t> dilation_;
  // Value-initialized: this is the only scalar member, and without an
  // initializer a subclass that forgot to set it would leave an indeterminate
  // value for unpack() to read.
  int64_t groups_{};
  c10::optional<Scalar> output_min_;
  c10::optional<Scalar> output_max_;

 public:
  // Returns copies of the original convolution arguments as
  // (weight, bias, stride, padding, dilation, groups, output_min, output_max).
  // const-qualified because it only reads members, so it is usable through a
  // const reference or pointer to the context.
  SerializationTypeConv2dPrePack unpack() const {
    return std::make_tuple(
        orig_weight_,
        orig_bias_,
        stride_,
        padding_,
        dilation_,
        groups_,
        output_min_,
        output_max_);
  }

  // Runs the convolution on `input` using the subclass's prepacked state.
  virtual Tensor run(const Tensor& input) = 0;
};
// Vulkan-backend implementation of Conv2dOpContext. It owns a ContextConv2D
// (the backend's prepacked convolution state) and records the original
// arguments in the base-class members so unpack() can return them.
class VulkanConv2dOpContext final : public Conv2dOpContext {
 private:
  // Prepacked Vulkan convolution state.
  // NOTE(review): presumably consumed by run(), which is defined out of line.
  ContextConv2D op_context_;

 public:
  // Takes ownership of the original arguments and of the prepacked context.
  //
  // The parameter order here is padding before stride, while the base class
  // (and unpack()) order is stride before padding. Each argument is stored
  // into the member of the same name, so unpack() still returns them in its
  // own order.
  //
  // `groups` is int64_t to match both the stored groups_ member and
  // create_context(). This avoids a signed -> unsigned -> signed round trip.
  VulkanConv2dOpContext(
      Tensor&& weight,
      c10::optional<Tensor>&& bias,
      std::vector<int64_t>&& padding,
      std::vector<int64_t>&& stride,
      std::vector<int64_t>&& dilation,
      int64_t groups,
      const c10::optional<Scalar>& min,
      const c10::optional<Scalar>& max,
      ContextConv2D&& op_context)
      : op_context_(std::move(op_context)) {
    // The stored arguments are base-class members, so they cannot be listed
    // in this constructor's member-initializer list; assign them here.
    orig_weight_ = std::move(weight);
    orig_bias_ = std::move(bias);
    padding_ = std::move(padding);
    stride_ = std::move(stride);
    dilation_ = std::move(dilation);
    groups_ = groups;
    output_min_ = min;
    output_max_ = max;
  }

  // Runs the convolution on `input` (defined out of line).
  Tensor run(const Tensor& input) override;

  // Factory that returns the new context through the Conv2dOpContext base
  // (defined out of line).
  // NOTE(review): presumably prepacks `weight`/`bias` into a ContextConv2D
  // and wraps it in a VulkanConv2dOpContext — confirm against the definition.
  static c10::intrusive_ptr<Conv2dOpContext> create_context(
      Tensor&& weight,
      c10::optional<Tensor>&& bias,
      std::vector<int64_t>&& padding,
      std::vector<int64_t>&& stride,
      std::vector<int64_t>&& dilation,
      int64_t groups,
      const c10::optional<Scalar>& output_min,
      const c10::optional<Scalar>& output_max);
};
} // namespace vulkan
} // namespace native
} // namespace at