[vulkan] enable prepacking for Batchnorm op (#88433)
Adds a `BatchNormPackedContext` so that the `batchnorm` op can use prepacking: the weight, bias, running mean, and running variance tensors are reshaped to `{C, 1, 1}` and moved to Vulkan storage once, at pack time, rather than on every call. New `vulkan_prepack::create_batchnorm_context` and `vulkan_prepack::run_batchnorm_context` ops are registered (with pickle support for the context class), and a JIT rewrite pass replaces `aten::batch_norm` with the prepacked pair during the Vulkan optimize-for-mobile pipeline. As before, only evaluation mode is supported; per texel, the shader computes `(x - mean) / sqrt(var + eps) * gamma + beta`.
Differential Revision: [D40721546](https://our.internmc.facebook.com/intern/diff/D40721546/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88433
Approved by: https://github.com/manuelcandales
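For context, a minimal C++ sketch of how the two new ops compose, assuming a Vulkan-enabled build; `prepacked_batchnorm_example` and its random parameter tensors are illustrative only and are not part of this diff:

```cpp
// Hedged usage sketch (not part of this PR): composing the new prepack and
// run ops directly from C++. In scripted models the JIT rewrite inserts
// these calls automatically.
#include <ATen/ATen.h>
#include <ATen/native/vulkan/ops/Batchnorm.h>

using namespace at::native::vulkan::ops;

at::Tensor prepacked_batchnorm_example(const at::Tensor& input_cpu) {
  // Per-channel parameters must be 1-dim with numel == C; the input must be
  // 4-dim with C divisible by 4 (both enforced by TORCH_CHECKs in the op).
  const int64_t C = input_cpu.size(1);
  const auto opts = at::device(at::kCPU).dtype(at::kFloat);

  // Packing happens once: each tensor is reshaped to {C, 1, 1} and moved to
  // Vulkan storage inside the context constructor.
  auto context = create_batchnorm_context(
      at::rand({C}, opts), // weight (gamma)
      at::rand({C}, opts), // bias (beta)
      at::rand({C}, opts), // running_mean
      at::rand({C}, opts), // running_var
      /*training=*/false,  // only eval mode is supported
      /*momentum=*/0.1,    // unused in eval mode
      /*eps=*/1e-5,
      /*cudnn_enable=*/false);

  // The packed context is reusable across inputs with matching channels.
  // The result is a Vulkan tensor; call .cpu() on it to read it back.
  return run_batchnorm_context(input_cpu.vulkan(), context);
}
```

In scripted models, `vulkanFoldPrePackingOps` additionally folds the `create_batchnorm_context` call into a constant, so packing happens once at model load time rather than per inference.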
diff --git a/aten/src/ATen/native/vulkan/glsl/batchnorm.glsl b/aten/src/ATen/native/vulkan/glsl/batchnorm.glsl
index 6ec9342..0ec7dbd 100644
--- a/aten/src/ATen/native/vulkan/glsl/batchnorm.glsl
+++ b/aten/src/ATen/native/vulkan/glsl/batchnorm.glsl
@@ -1,37 +1,61 @@
#version 450 core
#define PRECISION $precision
-#define FORMAT $format
+#define FORMAT $format
layout(std430) buffer;
-/* Qualifiers: layout - storage - precision - memory */
+/*
+ * Output Image
+ */
+layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
-layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
-layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
-layout(set = 0, binding = 2) uniform PRECISION sampler3D uGamma;
-layout(set = 0, binding = 3) uniform PRECISION sampler3D uBeta;
-layout(set = 0, binding = 4) uniform PRECISION sampler3D uMean;
-layout(set = 0, binding = 5) uniform PRECISION sampler3D uVar;
-layout(set = 0, binding = 6) uniform PRECISION restrict Block {
- ivec3 isize;
- int channels_ext;
+/*
+ * Input Textures
+ */
+layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
+layout(set = 0, binding = 2) uniform PRECISION sampler3D uGamma;
+layout(set = 0, binding = 3) uniform PRECISION sampler3D uBeta;
+layout(set = 0, binding = 4) uniform PRECISION sampler3D uMean;
+layout(set = 0, binding = 5) uniform PRECISION sampler3D uVar;
+
+/*
+ * Params Buffer
+ */
+layout(set = 0, binding = 6) uniform PRECISION restrict Block {
+ // xyz contains extents of the output texture, w contains the number of
+ // channels divided by 4, rounded up.
+ ivec4 out_extents;
float eps;
-} uBlock;
+}
+uBlock;
+/*
+ * Local Work Group
+ */
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+/*
+ * Computes batch normalization. Each shader invocation calculates the output
+ * at a single output location.
+ */
void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);
- if (all(lessThan(pos, uBlock.isize.xyz))) {
- const ivec3 chn = ivec3(0, 0, pos.z % uBlock.channels_ext);
- imageStore(
- uOutput,
- pos,
- (texelFetch(uInput, pos, 0)
- - texelFetch(uMean, chn, 0))
- / sqrt(texelFetch(uVar, chn, 0) + uBlock.eps)
- * texelFetch(uGamma, chn, 0)
- + texelFetch(uBeta, chn, 0));
+ // Return if this global position is outside output texture bounds
+ if (any(greaterThanEqual(pos, uBlock.out_extents.xyz))) {
+ return;
}
+
+ const ivec3 ch_pos = ivec3(0, 0, pos.z % uBlock.out_extents.w);
+
+ const vec4 in_tex = texelFetch(uInput, pos, 0);
+ const vec4 gamma_tex = texelFetch(uGamma, ch_pos, 0);
+ const vec4 beta_tex = texelFetch(uBeta, ch_pos, 0);
+ const vec4 mean_tex = texelFetch(uMean, ch_pos, 0);
+ const vec4 var_tex = texelFetch(uVar, ch_pos, 0);
+
+ const vec4 out_tex =
+ (in_tex - mean_tex) / sqrt(var_tex + uBlock.eps) * gamma_tex + beta_tex;
+
+ imageStore(uOutput, pos, out_tex);
}
diff --git a/aten/src/ATen/native/vulkan/ops/Batchnorm.cpp b/aten/src/ATen/native/vulkan/ops/Batchnorm.cpp
index 84828aa..d1fecca 100644
--- a/aten/src/ATen/native/vulkan/ops/Batchnorm.cpp
+++ b/aten/src/ATen/native/vulkan/ops/Batchnorm.cpp
@@ -1,109 +1,44 @@
-#include <ATen/native/vulkan/ops/Common.h>
+#include <ATen/Context.h>
+#include <ATen/native/vulkan/ops/Batchnorm.h>
#include <torch/library.h>
namespace at {
namespace native {
namespace vulkan {
namespace ops {
-namespace {
-using namespace api::utils;
+namespace batchnorm {
-Tensor batch_norm(
- const at::Tensor& input_arg,
- const c10::optional<Tensor>& weight_opt /* optional */,
- const c10::optional<Tensor>& bias_opt /* optional */,
- const c10::optional<Tensor>& running_mean_opt /* optional */,
- const c10::optional<Tensor>& running_var_opt /* optional */,
- bool training,
- double /* momentum, not used in eval mode */,
- double eps,
- bool /* cudnn_enable, deprecated */) {
- TORCH_CHECK(!training, "Vulkan batchnorm only supports evaluation mode.");
- TORCH_CHECK(
- weight_opt && weight_opt->defined() && bias_opt && bias_opt->defined(),
- "Vulkan batchnorm expects weight and bias arguments to be defined");
- TORCH_CHECK(
- running_mean_opt && running_mean_opt->defined(),
- "running_mean must be defined in evaluation mode.");
- TORCH_CHECK(
- running_var_opt && running_var_opt->defined(),
- "running_var must be defined in evaluation mode.");
- TORCH_CHECK(input_arg.dim() == 4, "Vulkan batchnorm expects 4-dim input!");
- TORCH_CHECK(
- get_dim<Dim4D::Channel>(input_arg) % 4 == 0,
- "Vulkan batchnorm expects channel dim to be multiple of 4!");
+struct Params final {
+ api::utils::ivec3 out_extents;
+ int32_t c4;
+ float eps;
+};
- const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan();
- const vTensor& v_input = convert(input);
- const IntArrayRef v_input_sizes = v_input.sizes();
+void record_op(
+ api::Context* const context,
+ vTensor& v_output,
+ const vTensor& v_input,
+ const vTensor& v_weight,
+ const vTensor& v_bias,
+ const vTensor& v_running_mean,
+ const vTensor& v_running_var,
+ const float eps) {
+ api::PipelineBarrier pipeline_barrier{};
- auto num_features = v_input.sizes()[1];
- auto channels_ext = num_features / 4;
+ api::utils::uvec3 global_size = v_output.extents();
+ api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
- const Tensor weight_opt_3d = weight_opt->reshape({num_features, 1, 1});
- const Tensor weight =
- weight_opt_3d.is_vulkan() ? weight_opt_3d : weight_opt_3d.vulkan();
- const vTensor& v_weight = convert(weight);
- TORCH_CHECK(
- weight.numel() == num_features,
- "weight tensor should contain ",
- num_features,
- " elements!");
+ uint32_t num_features = get_dim<Dim4D::Channel>(v_input.sizes());
+ uint32_t channels_ext = api::utils::div_up(num_features, 4u);
- const Tensor bias_opt_3d = bias_opt->reshape({num_features, 1, 1});
- const Tensor bias =
- bias_opt_3d.is_vulkan() ? bias_opt_3d : bias_opt_3d.vulkan();
- const vTensor& v_bias = convert(bias);
- TORCH_CHECK(
- bias.numel() == num_features,
- "bias tensor should contain ",
- num_features,
- " elements!");
-
- const Tensor running_mean_opt_3d =
- running_mean_opt->reshape({num_features, 1, 1});
- const Tensor running_mean = running_mean_opt_3d.is_vulkan()
- ? running_mean_opt_3d
- : running_mean_opt_3d.vulkan();
- const vTensor& v_running_mean = convert(running_mean);
- TORCH_CHECK(
- running_mean.numel() == num_features,
- "running mean tensor should contain ",
- num_features,
- " elements!");
-
- const Tensor running_var_opt_3d =
- running_var_opt->reshape({num_features, 1, 1});
- const Tensor running_var = running_var_opt_3d.is_vulkan()
- ? running_var_opt_3d
- : running_var_opt_3d.vulkan();
- const vTensor& v_running_var = convert(running_var);
- TORCH_CHECK(
- running_var.numel() == num_features,
- "running var tensor should contain ",
- num_features,
- " elements!");
-
- api::Context* const context = api::context();
-
- vTensor v_output{
- context,
- v_input_sizes,
- v_input.options(),
+ Params block{
+ api::utils::make_ivec3(v_output.extents()),
+ api::utils::safe_downcast<int32_t>(channels_ext),
+ eps,
};
- const struct Block final {
- uvec3 iextents;
- int32_t channels_ext;
- float epsilon;
- } block{
- v_output.extents(),
- safe_downcast<int32_t>(channels_ext),
- safe_downcast<float>(eps)};
-
api::UniformParamsBuffer params(context, block);
- api::PipelineBarrier pipeline_barrier{};
context->submit_compute_job(
// shader descriptor
@@ -111,9 +46,9 @@
// pipeline barrier
pipeline_barrier,
// global work group size
- v_output.extents(),
+ global_size,
// local work group size
- adaptive_work_group_size(v_output.extents()),
+ local_size,
// fence handle
VK_NULL_HANDLE,
// shader arguments
@@ -128,8 +63,34 @@
v_running_var.image(pipeline_barrier, api::PipelineStage::COMPUTE),
// params buffer
params.buffer());
+}
- return convert(v_output);
+} // namespace batchnorm
+
+namespace {
+
+using namespace api::utils;
+
+Tensor batch_norm(
+ const at::Tensor& input_arg,
+ const c10::optional<Tensor>& weight_opt /* optional */,
+ const c10::optional<Tensor>& bias_opt /* optional */,
+ const c10::optional<Tensor>& running_mean_opt /* optional */,
+ const c10::optional<Tensor>& running_var_opt /* optional */,
+ bool training,
+ double /* momentum, not used in eval mode */,
+ double eps,
+ bool /* cudnn_enable, deprecated */) {
+ TORCH_CHECK(!training, "Only evaluation mode is supported!");
+ TORCH_CHECK(input_arg.dim() == 4, "Input must have dim == 4!");
+ TORCH_CHECK(
+ get_dim<Dim4D::Channel>(input_arg) % 4 == 0,
+ "Input must have channels divisible by 4!");
+
+ return run_batchnorm_context(
+ input_arg,
+ c10::make_intrusive<BatchNormPackedContext>(BatchNormPackedContext(
+ weight_opt, bias_opt, running_mean_opt, running_var_opt, eps)));
}
#ifdef USE_VULKAN_API
@@ -141,6 +102,143 @@
#endif /* USE_VULKAN_API */
} // namespace
+
+BatchNormPackedContext::BatchNormPackedContext(
+ const c10::optional<Tensor>& weight_opt,
+ const c10::optional<Tensor>& bias_opt,
+ const c10::optional<Tensor>& running_mean_opt,
+ const c10::optional<Tensor>& running_var_opt,
+ double eps)
+ : unpacked_{c10::AnyType::get()} {
+ packed_.reserve(ListArgs::kNumArgs);
+
+  // Each optional tensor arg, if provided, should be a 1-dimensional tensor.
+  // To achieve more efficient packing as a texture, each is first reshaped to
+  // {N, 1, 1}. Eventually this rearrangement should happen automatically in
+  // vTensor itself.
+
+ // Weight
+ TORCH_CHECK(weight_opt, "Weight must be provided!");
+ TORCH_CHECK(weight_opt->dim() == 1, "Weight must have ndim == 1!");
+
+ const int64_t num_features =
+ api::utils::safe_downcast<int64_t>(weight_opt->numel());
+ const Tensor weight_3d = weight_opt->reshape({num_features, 1, 1});
+ packed_.emplace_back(weight_3d.vulkan());
+
+ // Bias
+ TORCH_CHECK(bias_opt, "Bias must be provided!");
+ TORCH_CHECK(bias_opt->dim() == 1, "Bias must have ndim == 1!");
+ TORCH_CHECK(
+ bias_opt->numel() == num_features,
+ "Bias must have the same numel as weight!");
+
+ const Tensor bias_3d = bias_opt->reshape({num_features, 1, 1});
+ packed_.emplace_back(bias_3d.vulkan());
+
+ // Running Mean
+ TORCH_CHECK(running_mean_opt, "Running mean must be provided!");
+  TORCH_CHECK(running_mean_opt->dim() == 1, "Running mean must have ndim == 1!");
+ TORCH_CHECK(
+ running_mean_opt->numel() == num_features,
+ "Running mean must have the same numel as weight!");
+
+ const Tensor running_mean_3d =
+ running_mean_opt->reshape({num_features, 1, 1});
+ packed_.emplace_back(running_mean_3d.vulkan());
+
+ // Running var
+ TORCH_CHECK(running_var_opt, "Running var must be provided!");
+  TORCH_CHECK(running_var_opt->dim() == 1, "Running var must have ndim == 1!");
+ TORCH_CHECK(
+ running_var_opt->numel() == num_features,
+ "Running var must have the same numel as weight!");
+
+ const Tensor running_var_3d = running_var_opt->reshape({num_features, 1, 1});
+ packed_.emplace_back(running_var_3d.vulkan());
+
+ // Epsilon
+ packed_.emplace_back(eps);
+
+ if (!at::globalContext().releaseWeightsWhenPrepacking()) {
+ unpacked_.reserve(ListArgs::kNumArgs);
+ unpacked_.emplace_back(weight_opt);
+ unpacked_.emplace_back(bias_opt);
+ unpacked_.emplace_back(running_mean_opt);
+ unpacked_.emplace_back(running_var_opt);
+ unpacked_.emplace_back(eps);
+ }
+}
+
+BatchNormPackedContext BatchNormPackedContext::pack(
+ c10::impl::GenericList unpacked) {
+ return BatchNormPackedContext(
+ get_optional_tensor(unpacked, ListArgs::kWeight),
+ get_optional_tensor(unpacked, ListArgs::kBias),
+ get_optional_tensor(unpacked, ListArgs::kRunningMean),
+ get_optional_tensor(unpacked, ListArgs::kRunningVar),
+ unpacked.get(ListArgs::kEps).toDouble());
+}
+
+c10::intrusive_ptr<BatchNormPackedContext> create_batchnorm_context(
+ c10::optional<Tensor>&& weight_opt,
+ c10::optional<Tensor>&& bias_opt,
+ c10::optional<Tensor>&& running_mean_opt,
+ c10::optional<Tensor>&& running_var_opt,
+ bool training,
+ double /* momentum */,
+ double eps,
+ bool /* cudnn_enable, deprecated */) {
+ return c10::make_intrusive<BatchNormPackedContext>(BatchNormPackedContext(
+ weight_opt, bias_opt, running_mean_opt, running_var_opt, eps));
+}
+
+Tensor run_batchnorm_context(
+ const Tensor& input_arg,
+ const c10::intrusive_ptr<BatchNormPackedContext>& batchnorm_context) {
+ api::Context* const context = api::context();
+
+ const vTensor& v_input = convert(input_arg);
+
+ const vTensor& v_weight = convert(
+ batchnorm_context->get_val(BatchNormPackedContext::ListArgs::kWeight)
+ .toTensor());
+
+ const vTensor& v_bias = convert(
+ batchnorm_context->get_val(BatchNormPackedContext::ListArgs::kBias)
+ .toTensor());
+
+ const vTensor& v_running_mean = convert(
+ batchnorm_context->get_val(BatchNormPackedContext::ListArgs::kRunningMean)
+ .toTensor());
+
+ const vTensor& v_running_var = convert(
+ batchnorm_context->get_val(BatchNormPackedContext::ListArgs::kRunningVar)
+ .toTensor());
+
+ const float eps = api::utils::safe_downcast<float>(
+ batchnorm_context->get_val(BatchNormPackedContext::ListArgs::kEps)
+ .toDouble());
+
+ vTensor v_output{
+ context,
+ v_input.sizes(),
+ v_input.options(),
+ };
+
+ batchnorm::record_op(
+ context,
+ v_output,
+ v_input,
+ v_weight,
+ v_bias,
+ v_running_mean,
+ v_running_var,
+ eps);
+
+ return convert(v_output);
+}
+
} // namespace ops
} // namespace vulkan
} // namespace native
diff --git a/aten/src/ATen/native/vulkan/ops/Batchnorm.h b/aten/src/ATen/native/vulkan/ops/Batchnorm.h
new file mode 100644
index 0000000..6afaeb6
--- /dev/null
+++ b/aten/src/ATen/native/vulkan/ops/Batchnorm.h
@@ -0,0 +1,68 @@
+#pragma once
+
+#ifdef USE_VULKAN_API
+
+#include <ATen/native/vulkan/ops/Common.h>
+#include <ATen/native/vulkan/ops/VulkanPackedContext.h>
+#include <torch/library.h>
+
+namespace at {
+namespace native {
+namespace vulkan {
+namespace ops {
+
+class BatchNormPackedContext final : virtual public VulkanPackedContext,
+ public torch::jit::CustomClassHolder {
+ private:
+ c10::impl::GenericList unpacked_;
+
+ public:
+ BatchNormPackedContext(
+ const c10::optional<Tensor>& weight_opt,
+ const c10::optional<Tensor>& bias_opt,
+ const c10::optional<Tensor>& running_mean_opt,
+ const c10::optional<Tensor>& running_var_opt,
+ double eps);
+
+ /*
+ * Assigns a name to each index in the packed/unpacked list.
+ */
+ struct ListArgs final {
+ static constexpr uint32_t kWeight = 0u;
+ static constexpr uint32_t kBias = 1u;
+ static constexpr uint32_t kRunningMean = 2u;
+ static constexpr uint32_t kRunningVar = 3u;
+ static constexpr uint32_t kEps = 4u;
+
+ static constexpr uint32_t kNumArgs = 5u;
+ };
+
+ static BatchNormPackedContext pack(c10::impl::GenericList);
+
+ const c10::impl::GenericList unpack() const override {
+ TORCH_CHECK(unpacked_.size() > 0u, "unpacked_ does not have any elements!");
+
+ return unpacked_;
+ }
+};
+
+c10::intrusive_ptr<BatchNormPackedContext> create_batchnorm_context(
+ c10::optional<Tensor>&& weight_opt,
+ c10::optional<Tensor>&& bias_opt,
+ c10::optional<Tensor>&& running_mean_opt,
+ c10::optional<Tensor>&& running_var_opt,
+ bool training,
+ double /* momentum */,
+ double eps,
+ bool /* cudnn_enable, deprecated */);
+
+Tensor run_batchnorm_context(
+ const Tensor& input_arg,
+ const c10::intrusive_ptr<BatchNormPackedContext>& context);
+
+} // namespace ops
+} // namespace vulkan
+} // namespace native
+} // namespace at
+
+#endif /* USE_VULKAN_API */
diff --git a/aten/src/ATen/native/vulkan/ops/Register.cpp b/aten/src/ATen/native/vulkan/ops/Register.cpp
index 18d5a6f..25f0a6d 100644
--- a/aten/src/ATen/native/vulkan/ops/Register.cpp
+++ b/aten/src/ATen/native/vulkan/ops/Register.cpp
@@ -1,5 +1,6 @@
#ifdef USE_VULKAN_API
+#include <ATen/native/vulkan/ops/Batchnorm.h>
#include <ATen/native/vulkan/ops/Common.h>
#include <ATen/native/vulkan/ops/Convolution.h>
#include <ATen/native/vulkan/ops/Gru.h>
@@ -16,6 +17,19 @@
namespace {
TORCH_LIBRARY(vulkan, m) {
+ m.class_<BatchNormPackedContext>("BatchNormPackedContext")
+ .def_pickle(
+ // __getstate__
+ [](const c10::intrusive_ptr<BatchNormPackedContext>& context) {
+ // context is packed
+ return context->unpack();
+ },
+ // __setstate__
+ [](c10::impl::GenericList state) {
+ // state is unpacked
+ return c10::make_intrusive<BatchNormPackedContext>(
+ BatchNormPackedContext::pack(state));
+ });
m.class_<LinearPackedContext>("LinearPackedContext")
.def_pickle(
// __getstate__
@@ -147,6 +161,22 @@
"Tensor hx_vk, "
"Tensor cx_vk, "
"__torch__.torch.classes.vulkan.LstmPackedContext L_prepack) -> (Tensor next_input, Tensor hidden_state, Tensor cell_state)"));
+ m.def(TORCH_SELECTIVE_SCHEMA(
+ "vulkan_prepack::create_batchnorm_context("
+ "Tensor? weight_opt, "
+ "Tensor? bias_opt, "
+ "Tensor? running_mean_opt, "
+ "Tensor? running_var_opt, "
+ "bool training, "
+ "float momentum, "
+ "float eps, "
+ "bool cudnn_enable) "
+ "-> __torch__.torch.classes.vulkan.BatchNormPackedContext"));
+ m.def(TORCH_SELECTIVE_SCHEMA(
+ "vulkan_prepack::run_batchnorm_context("
+ "Tensor input_vk, "
+ "__torch__.torch.classes.vulkan.BatchNormPackedContext context) "
+ "-> Tensor out"));
}
TORCH_LIBRARY_IMPL(vulkan_prepack, CPU, m) {
@@ -168,6 +198,9 @@
m.impl(
TORCH_SELECTIVE_NAME("vulkan_prepack::create_lstm_context"),
TORCH_FN(create_lstm_context));
+ m.impl(
+ TORCH_SELECTIVE_NAME("vulkan_prepack::create_batchnorm_context"),
+ TORCH_FN(create_batchnorm_context));
}
TORCH_LIBRARY_IMPL(vulkan_prepack, Vulkan, m) {
@@ -189,6 +222,9 @@
m.impl(
TORCH_SELECTIVE_NAME("vulkan_prepack::run_lstm_context"),
TORCH_FN(run_lstm_context));
+ m.impl(
+ TORCH_SELECTIVE_NAME("vulkan_prepack::run_batchnorm_context"),
+ TORCH_FN(run_batchnorm_context));
}
} // namespace
diff --git a/aten/src/ATen/test/vulkan_api_test.cpp b/aten/src/ATen/test/vulkan_api_test.cpp
index 2519267..a9dc190 100644
--- a/aten/src/ATen/test/vulkan_api_test.cpp
+++ b/aten/src/ATen/test/vulkan_api_test.cpp
@@ -630,7 +630,7 @@
at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
- true,
+ false,
0.1,
1e-05,
false);
@@ -644,7 +644,7 @@
at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
- true,
+ false,
0.1,
1e-05,
false);
@@ -658,7 +658,7 @@
at::rand({7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
at::rand({7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
at::rand({7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
- true,
+ false,
0.1,
1e-05,
false);
@@ -672,7 +672,7 @@
at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
- true,
+ false,
0.1,
1e-05,
false);
@@ -686,7 +686,7 @@
at::rand({12}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
- true,
+ false,
0.1,
1e-05,
false);
@@ -700,7 +700,7 @@
at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
at::rand({12}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
- true,
+ false,
0.1,
1e-05,
false);
@@ -714,7 +714,7 @@
at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
at::rand({12}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
- true,
+ false,
0.1,
1e-05,
false);
diff --git a/torch/csrc/jit/passes/vulkan_rewrite.cpp b/torch/csrc/jit/passes/vulkan_rewrite.cpp
index cbfa612..0c37d5b 100644
--- a/torch/csrc/jit/passes/vulkan_rewrite.cpp
+++ b/torch/csrc/jit/passes/vulkan_rewrite.cpp
@@ -19,6 +19,24 @@
namespace {
+void insertPrePackedBatchNormOp(std::shared_ptr<Graph>& graph) {
+ std::string batchnorm_pattern = R"(
+ graph(%input, %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable):
+ %r = aten::batch_norm(%input, %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable)
+ return (%r))";
+ std::string prepacked_ops_pattern = R"(
+ graph(%input, %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable):
+ %op_context : __torch__.torch.classes.vulkan.BatchNormPackedContext = vulkan_prepack::create_batchnorm_context(
+ %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable)
+ %res = vulkan_prepack::run_batchnorm_context(%input, %op_context)
+ return (%res))";
+
+ SubgraphRewriter batchnorm_rewriter;
+ batchnorm_rewriter.RegisterRewritePattern(
+ batchnorm_pattern, prepacked_ops_pattern);
+ batchnorm_rewriter.runOnGraph(graph);
+}
+
void insertPrePackedLinearOp(std::shared_ptr<Graph>& graph) {
// fuse decomposed linear into aten::linear
FuseLinear(graph);
@@ -265,6 +283,7 @@
insertPrePackedConv2dOp(graph);
insertPrePackedGruOp(graph);
insertPrePackedLstmOp(graph);
+ insertPrePackedBatchNormOp(graph);
}
void vulkanInsertPrePackedOps(script::Module& module) {
@@ -295,7 +314,9 @@
(n->kind() ==
Symbol::fromQualString("vulkan_prepack::create_gru_context")) ||
(n->kind() ==
- Symbol::fromQualString("vulkan_prepack::create_lstm_context")));
+ Symbol::fromQualString("vulkan_prepack::create_lstm_context")) ||
+ (n->kind() ==
+ Symbol::fromQualString("vulkan_prepack::create_batchnorm_context")));
};
PrePackingOpsFolder(m, filter_fn, "prepack_folding");
}
@@ -320,18 +341,20 @@
auto cloned_module = m.clone();
cloned_module.eval();
cloned_module = FoldConvBatchNorm(cloned_module);
- vulkanInsertPrePackedOps(cloned_module);
cloned_module = freeze_module(cloned_module, preserved_methods);
+ vulkanInsertPrePackedOps(cloned_module);
+ vulkanFusePrePackedConvWithClamp(cloned_module);
+ vulkanFoldPrePackingOps(cloned_module);
+ removeDropout(cloned_module);
+ vulkanRemoveMutation(cloned_module);
+
if (!optimization_blocklist.count(
MobileOptimizerType::VULKAN_AUTOMATIC_GPU_TRANSFER)) {
transferInputOutputBackends(cloned_module);
cloned_module.register_attribute(
"requires_backend_transfers", BoolType::get(), false);
}
- vulkanFusePrePackedConvWithClamp(cloned_module);
- vulkanFoldPrePackingOps(cloned_module);
- removeDropout(cloned_module);
- vulkanRemoveMutation(cloned_module);
+
// remove duplicated constants
vulkanRunCanonicalOptimizations(cloned_module);
eliminateDeadCode(cloned_module);