[vulkan] convolution old prepacking via cpu-shader (#48330)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48330
Test Plan: Imported from OSS
Reviewed By: SS-JIA
Differential Revision: D25131500
Pulled By: IvanKobzarev
fbshipit-source-id: b11edb94a78f5d6283c7be1887d72a4ca624a9ab
diff --git a/aten/src/ATen/native/vulkan/glsl/conv2d_nogroup_clamp_1x.glsl b/aten/src/ATen/native/vulkan/glsl/conv2d_nogroup_clamp_1x.glsl
new file mode 100644
index 0000000..5cef89c
--- /dev/null
+++ b/aten/src/ATen/native/vulkan/glsl/conv2d_nogroup_clamp_1x.glsl
@@ -0,0 +1,66 @@
+#version 450 core
+#define PRECISION $precision
+layout(std430) buffer;
+layout(std430) uniform;
+layout(set = 0, rgba32f, binding = 0) writeonly PRECISION uniform image3D uOutput;
+layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
+layout(set = 0, binding = 2) uniform PRECISION sampler3D uKernel;
+layout(set = 0, binding = 3) readonly buffer bias {
+ vec4 data[];
+}
+uBias;
+layout(set = 0, binding = 4) uniform constBlock {
+ ivec2 padding;
+ ivec2 kernelSize;
+ ivec2 stride;
+ ivec2 dilate;
+ ivec4 outputSize;
+ ivec4 inputSize;
+ float outputMin;
+ float outputMax;
+}
+uConstBlock;
+
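+// Integer ceiling division, e.g. UP_DIV(5, 4) == 2; used below to clip the
+// kernel window to the valid input region.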
+#define UP_DIV(x, y) (((x) + (y)-1) / (y))
+
+layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in;
+
+void main() {
+ ivec3 pos = ivec3(gl_GlobalInvocationID);
+ if (all(lessThan(pos, uConstBlock.outputSize.xyz))) {
+ int kernelX = uConstBlock.kernelSize.x;
+ int kernelY = uConstBlock.kernelSize.y;
+ ivec3 inputSize = uConstBlock.inputSize.xyz;
+ ivec2 s0 = pos.xy * uConstBlock.stride - uConstBlock.padding;
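+    // Top-left input coordinate covered by this output position; may be
+    // negative inside the padding border.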
+ int fx, fy, fz;
+ ivec2 sfxy = max(ivec2(0), (UP_DIV(-s0, uConstBlock.dilate)));
+ ivec2 efxy =
+ min(uConstBlock.kernelSize,
+ UP_DIV(uConstBlock.inputSize.xy - s0, uConstBlock.dilate));
+ vec4 color = uBias.data[pos.z];
+ int kY = pos.z;
+ int strideX = uConstBlock.stride.x;
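+    // The prepacked kernel texture stores at texels (4*fz .. 4*fz+3, kY, kZ)
+    // a 4x4 matrix mapping input-channel group fz to output-channel group
+    // kY (= pos.z) at spatial tap kZ = fx + fy*kernelX, so each mat4 * vec4
+    // below accumulates 4 input channels into 4 output channels at once.
+    // Note that fx, unlike fy, is not clipped to [sfxy.x, efxy.x).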
+ for (fy = sfxy.y; fy < efxy.y; ++fy) {
+ int sy = fy * uConstBlock.dilate.y + s0.y;
+ for (fx = 0; fx < kernelX; ++fx) {
+ int kZ = fx + fy * kernelX;
+ int sx = fx * uConstBlock.dilate.x + s0.x;
+ fz = 0;
+ for (; fz < inputSize.z; ++fz) {
+ int kX = 4 * fz;
+ vec4 k0 = texelFetch(uKernel, ivec3(kX + 0, kY, kZ), 0);
+ vec4 k1 = texelFetch(uKernel, ivec3(kX + 1, kY, kZ), 0);
+ vec4 k2 = texelFetch(uKernel, ivec3(kX + 2, kY, kZ), 0);
+ vec4 k3 = texelFetch(uKernel, ivec3(kX + 3, kY, kZ), 0);
+
+ mat4 k = mat4(k0, k1, k2, k3);
+
+ color += k * texelFetch(uInput, ivec3(sx, sy, fz), 0);
+ }
+ }
+ }
+ vec4 outputMin = vec4(uConstBlock.outputMin);
+ vec4 outputMax = vec4(uConstBlock.outputMax);
+ imageStore(uOutput, ivec3(pos.x, pos.y, pos.z), clamp(color, outputMin, outputMax));
+ }
+}
diff --git a/aten/src/ATen/native/vulkan/ops/Common.h b/aten/src/ATen/native/vulkan/ops/Common.h
index 3c9b2e8..d8b9905 100644
--- a/aten/src/ATen/native/vulkan/ops/Common.h
+++ b/aten/src/ATen/native/vulkan/ops/Common.h
@@ -35,6 +35,10 @@
};
};
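+// Compile-time experiment switch: when true, conv2d falls back to the old
+// CPU-side weight prepacking and the old conv2d_nogroup_clamp shader.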
+struct Experimentation {
+ static constexpr bool kUseConv2dOldApi = true;
+};
+
} // namespace ops
} // namespace vulkan
} // namespace native
diff --git a/aten/src/ATen/native/vulkan/ops/Convolution.cpp b/aten/src/ATen/native/vulkan/ops/Convolution.cpp
index c549468..0a12c7b 100644
--- a/aten/src/ATen/native/vulkan/ops/Convolution.cpp
+++ b/aten/src/ATen/native/vulkan/ops/Convolution.cpp
@@ -1,4 +1,5 @@
#include <ATen/native/vulkan/ops/Convolution.h>
+#include <ATen/native/vulkan/api/Utils.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/utils/ParamUtils.h>
#include <ATen/native/vulkan/ops/Persistent.h>
@@ -66,6 +67,91 @@
// General
//
+ if (Experimentation::kUseConv2dOldApi) {
+ const uint32_t OC = src_filter[Layout::Filter::output];
+ const uint32_t OC_4 = at::native::vulkan::api::utils::div_up(OC, 4u);
+ const uint32_t C = src_filter[Layout::Filter::input];
+ const uint32_t C_4 = at::native::vulkan::api::utils::div_up(C, 4u);
+ const uint32_t KH = src_filter[Layout::Filter::height];
+ const uint32_t KW = src_filter[Layout::Filter::width];
+
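+      // Staging layout {N=1, C=4*KH*KW, H=OC_4, W=4*C_4}: these sizes match
+      // the W, H and D used by the repacking passes below.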
+ vTensor v_weight{
+ api::context(),
+ &pool,
+ {
+ 1,
+ 4 * KH * KW,
+ OC_4,
+ 4 * C_4
+ },
+ weight.options(),
+ };
+
+ using Future = vTensor::Future<float, vTensor::Access::Write>;
+ Future v_weight_future = v_weight.host<float, vTensor::Access::Write>();
+ Future::Payload v_weight_payload = v_weight_future.wait();
+
+ float* const dst_weight_ptr = v_weight_payload.get();
+ memset(dst_weight_ptr, 0, v_weight.nbytes());
+
+ const float* src = src_weight_ptr;
+ float* const dst = dst_weight_ptr;
+
+ {
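+      // Pass 1: repack OIHW source weights into KO4C4HW order, i.e.
+      //   dst[oc_4][ic_4][ky][kx][4 * ic_4_i + oc_4_i]
+      // with one 16-float (4 in x 4 out channel) block per spatial tap.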
+ uint32_t ridx = 0;
+ const uint32_t oc_4SizeNumel = KW * KH * C_4 * 16;
+ for (uint32_t oc = 0; oc < OC; ++oc) {
+ int oc_4 = oc / 4;
+ int oc_4_i = oc % 4;
+ float* dst_oc = dst + oc_4 * oc_4SizeNumel;
+ for (uint32_t ic = 0; ic < C; ++ic) {
+ int ic_4 = ic / 4;
+ int ic_4_i = ic % 4;
+ float* dst_ic = dst_oc + ic_4 * KW * KH * 16;
+ for (uint32_t ky = 0; ky < KH; ++ky) {
+ float* dst_ky = dst_ic + ky * KW * 16;
+ for (uint32_t kx = 0; kx < KW; ++kx) {
+ float* dst_kx = dst_ky + kx * 16;
+ dst_kx[4 * ic_4_i + oc_4_i] = src[ridx++];
+ }
+ }
+ }
+ }
+
+      // Pass 2 (CPU equivalent of the KO4C4HW_to_image shader): gather the
+      // packed buffer into texel order image[4*C_4][OC_4][KH*KW][4].
+      // (The stack array below relies on the variable-length-array extension.)
+ float image[4 * C_4][OC_4][KH * KW][4];
+      memset(image, 0, 16 * C_4 * OC_4 * KH * KW * sizeof(float));
+ for (uint32_t sx = 0; sx < C_4; ++sx) {
+ for (uint32_t sy = 0; sy < OC_4; ++sy) {
+ for (uint32_t sz = 0; sz < (KH * KW); ++sz) {
+ for (uint32_t vi = 0; vi < 4; ++vi) {
+ int bufferVIdx = 4 * sx * KH * KW + 4 * sy * C_4 * KH * KW + 4 * sz;
+ image[4 * sx + 0][sy][sz][vi] = dst[4 * (bufferVIdx + 0) + vi];
+ image[4 * sx + 1][sy][sz][vi] = dst[4 * (bufferVIdx + 1) + vi];
+ image[4 * sx + 2][sy][sz][vi] = dst[4 * (bufferVIdx + 2) + vi];
+ image[4 * sx + 3][sy][sz][vi] = dst[4 * (bufferVIdx + 3) + vi];
+ }
+ }
+ }
+ }
+
+      // Pass 3: apply the inverse of nchw_to_image so that the subsequent
+      // nchw_to_image upload of this host buffer reproduces the texel
+      // layout built in pass 2.
+ const uint32_t W = 4 * C_4;
+ const uint32_t H = OC_4;
+ const uint32_t D = KH * KW;
+ for (uint32_t sx = 0; sx < W; ++sx) {
+ for (uint32_t sy = 0; sy < H; ++sy) {
+ for (uint32_t sz = 0; sz < D; ++sz) {
+ for (uint32_t szvi = 0; szvi < 4; ++szvi) {
+ dst_weight_ptr[W * sy + sx + (4 * sz + szvi) * W * H] = image[sx][sy][sz][szvi];
+ }
+ }
+ }
+ }
+ }
+
+ return v_weight;
+ }
vTensor v_weight{
api::context(),
@@ -624,6 +710,99 @@
};
}
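+// Dispatches the old-API convolution shader against weights prepacked on
+// the CPU by the Experimentation::kUseConv2dOldApi branch added above.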
+void conv2d_old(
+ api::Context* const context,
+ api::Command::Buffer& command_buffer,
+ vTensor& v_output,
+ const vTensor& v_input,
+ const vTensor& v_weight,
+ const vTensor& v_bias,
+ const IntArrayRef filter,
+ const IntArrayRef stride,
+ const IntArrayRef padding,
+ const IntArrayRef dilation,
+ const float output_min,
+ const float output_max) {
+
+ using namespace api::utils;
+
+ if (v_output.has_image() && v_input.has_image() && v_weight.has_image()) {
+ const int32_t W = v_input.extents().data[0];
+ const int32_t H = v_input.extents().data[1];
+ const int32_t C_4 = v_input.extents().data[2];
+ const int32_t C = 4 * C_4;
+
+ const int32_t OW = v_output.extents().data[0];
+ const int32_t OH = v_output.extents().data[1];
+ const int32_t OC_4 = v_output.extents().data[2];
+ const int32_t OC = 4 * OC_4;
+
+ const struct {
+ int32_t padding_x, padding_y;
+ int32_t kernel_x, kernel_y;
+ int32_t stride_x, stride_y;
+ int32_t dilate_x, dilate_y;
+ int32_t outputSize[4];
+ int32_t inputSize[4];
+ float outputMin;
+ float outputMax;
+ } block {
+ safe_downcast<int32_t>(padding[Layout::Parameter::width]),
+ safe_downcast<int32_t>(padding[Layout::Parameter::height]),
+ safe_downcast<int32_t>(filter[Layout::Filter::width]),
+ safe_downcast<int32_t>(filter[Layout::Filter::height]),
+ safe_downcast<int32_t>(stride[Layout::Parameter::width]),
+ safe_downcast<int32_t>(stride[Layout::Parameter::height]),
+ safe_downcast<int32_t>(dilation[Layout::Parameter::width]),
+ safe_downcast<int32_t>(dilation[Layout::Parameter::height]),
+ { OW, OH, OC_4, OC },
+ { W, H, C_4, C },
+ output_min,
+ output_max,
+ };
+
+ context->dispatch(
+ command_buffer,
+ {
+ VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
+ VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
+ VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
+ VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
+ VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
+ },
+ VK_KERNEL(conv2d_nogroup_clamp),
+ //VK_KERNEL(conv2d_nogroup_clamp_1x),
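+      // conv2d_nogroup_clamp_1x, added in this change, is an alternative
+      // kernel that is left disabled here.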
+ v_output.extents(),
+ // Write-only access bypasses synchronization but inserts appropriate
+ // barriers if necessary.
+ v_output.image(
+ command_buffer,
+ vTensor::Stage::Compute,
+ vTensor::Access::Write),
+ // Read-only access is implied on const tensors and triggers an async
+ // synchronization if necessary.
+ v_input.image(
+ command_buffer,
+ vTensor::Stage::Compute),
+ // Read-only access is implied on const tensors and triggers an async
+ // synchronization if necessary.
+ v_weight.image(
+ command_buffer,
+ vTensor::Stage::Compute),
+ // Read-only access is implied on const tensors and triggers an async
+ // synchronization if necessary.
+ v_bias.buffer(
+ command_buffer,
+ vTensor::Stage::Compute),
+ // Object lifetime is managed by the resource pool.
+ // It is OK not to keep track of the handle.
+ context->resource().pool.uniform(block).object);
+ }
+ else {
+ TORCH_CHECK(false, "Not implemented!");
+ }
+}
+
Tensor Conv2dOpContext::run(const Tensor& input_arg) const {
api::Context* const context = api::context();
@@ -664,34 +843,52 @@
packed_.output_min,
packed_.output_max);
}
- else if (is_pointwise(unpacked_.filter)) {
- conv2d_pointwise(
- context,
- command_buffer,
- v_output,
- v_input,
- packed_.v_weight,
- packed_.v_bias,
- packed_.filter,
- packed_.stride,
- packed_.padding,
- packed_.output_min,
- packed_.output_max);
- }
else {
- conv2d(
- context,
- command_buffer,
- v_output,
- v_input,
- packed_.v_weight,
- packed_.v_bias,
- packed_.filter,
- packed_.stride,
- packed_.padding,
- packed_.dilation,
- packed_.output_min,
- packed_.output_max);
+ if (Experimentation::kUseConv2dOldApi) {
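+      // The old path covers both the pointwise and the general case.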
+ conv2d_old(
+ context,
+ command_buffer,
+ v_output,
+ v_input,
+ packed_.v_weight,
+ packed_.v_bias,
+ packed_.filter,
+ packed_.stride,
+ packed_.padding,
+ packed_.dilation,
+ packed_.output_min,
+ packed_.output_max);
+ } else {
+ if (is_pointwise(unpacked_.filter)) {
+ conv2d_pointwise(
+ context,
+ command_buffer,
+ v_output,
+ v_input,
+ packed_.v_weight,
+ packed_.v_bias,
+ packed_.filter,
+ packed_.stride,
+ packed_.padding,
+ packed_.output_min,
+ packed_.output_max);
+ }
+ else {
+ conv2d(
+ context,
+ command_buffer,
+ v_output,
+ v_input,
+ packed_.v_weight,
+ packed_.v_bias,
+ packed_.filter,
+ packed_.stride,
+ packed_.padding,
+ packed_.dilation,
+ packed_.output_min,
+ packed_.output_max);
+ }
+ }
}
}
command_buffer.end();