[vulkan] enable prepacking for Batchnorm op (#88433)

Adds a `BatchNormPackedContext` so that the `batchnorm` op can use prepacking: the weight, bias, running mean, and running variance tensors are validated, reshaped to `{N, 1, 1}`, and uploaded to the GPU once when the context is created, rather than on every invocation of the op.

Differential Revision: [D40721546](https://our.internmc.facebook.com/intern/diff/D40721546/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88433
Approved by: https://github.com/manuelcandales
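
For reference, a minimal sketch (not part of this PR) of how the new prepacked path can be driven directly from C++. The wrapper name `vulkan_batchnorm_example` is hypothetical; `create_batchnorm_context` and `run_batchnorm_context` are the functions added in this diff. As checked below, the input must be 4-dimensional with a channel count divisible by 4, and only evaluation mode is supported:

```cpp
#include <ATen/ATen.h>
#include <ATen/native/vulkan/ops/Batchnorm.h>

// Hypothetical wrapper, for illustration only.
at::Tensor vulkan_batchnorm_example(const at::Tensor& cpu_input) {
  using namespace at::native::vulkan::ops;
  const int64_t C = cpu_input.size(1); // channels; must be divisible by 4

  // Pack once: each 1-D parameter tensor is validated, reshaped to
  // {C, 1, 1}, and uploaded as a Vulkan texture inside the context.
  auto context = create_batchnorm_context(
      at::ones({C}),   // weight (gamma)
      at::zeros({C}),  // bias (beta)
      at::zeros({C}),  // running_mean
      at::ones({C}),   // running_var
      /*training=*/false,
      /*momentum=*/0.1,
      /*eps=*/1e-5,
      /*cudnn_enable=*/false);

  // Run as many times as needed: the packed textures are reused each call.
  return run_batchnorm_context(cpu_input.vulkan(), context).cpu();
}
```

In a TorchScript model none of this is written by hand; the `insertPrePackedBatchNormOp` rewrite pass added in `vulkan_rewrite.cpp` swaps `aten::batch_norm` for this create/run pair automatically.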
diff --git a/aten/src/ATen/native/vulkan/glsl/batchnorm.glsl b/aten/src/ATen/native/vulkan/glsl/batchnorm.glsl
index 6ec9342..0ec7dbd 100644
--- a/aten/src/ATen/native/vulkan/glsl/batchnorm.glsl
+++ b/aten/src/ATen/native/vulkan/glsl/batchnorm.glsl
@@ -1,37 +1,61 @@
 #version 450 core
 #define PRECISION $precision
-#define FORMAT    $format
+#define FORMAT $format
 
 layout(std430) buffer;
 
-/* Qualifiers: layout - storage - precision - memory */
+/*
+ * Output Image
+ */
+layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
 
-layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D   uOutput;
-layout(set = 0, binding = 1)         uniform PRECISION                    sampler3D uInput;
-layout(set = 0, binding = 2)         uniform PRECISION                    sampler3D uGamma;
-layout(set = 0, binding = 3)         uniform PRECISION                    sampler3D uBeta;
-layout(set = 0, binding = 4)         uniform PRECISION                    sampler3D uMean;
-layout(set = 0, binding = 5)         uniform PRECISION                    sampler3D uVar;
-layout(set = 0, binding = 6)         uniform PRECISION restrict           Block {
-  ivec3 isize;
-  int channels_ext;
+/*
+ * Input Textures
+ */
+layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
+layout(set = 0, binding = 2) uniform PRECISION sampler3D uGamma;
+layout(set = 0, binding = 3) uniform PRECISION sampler3D uBeta;
+layout(set = 0, binding = 4) uniform PRECISION sampler3D uMean;
+layout(set = 0, binding = 5) uniform PRECISION sampler3D uVar;
+
+/*
+ * Params Buffer
+ */
+layout(set = 0, binding = 6) uniform PRECISION restrict Block {
+  // xyz contains extents of the output texture, w contains the number of
+  // channels divided by 4, rounded up.
+  ivec4 out_extents;
   float eps;
-} uBlock;
+}
+uBlock;
 
+/*
+ * Local Work Group
+ */
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
+/*
+ * Computes batch normalization. Each shader invocation calculates the output
+ * at a single output location.
+ */
 void main() {
   const ivec3 pos = ivec3(gl_GlobalInvocationID);
 
-  if (all(lessThan(pos, uBlock.isize.xyz))) {
-    const ivec3 chn = ivec3(0, 0, pos.z % uBlock.channels_ext);
-    imageStore(
-        uOutput,
-        pos,
-        (texelFetch(uInput, pos, 0)
-            - texelFetch(uMean, chn, 0))
-            / sqrt(texelFetch(uVar, chn, 0) + uBlock.eps)
-            * texelFetch(uGamma, chn, 0)
-            + texelFetch(uBeta, chn, 0));
+  // Return if this global position is outside output texture bounds
+  if (any(greaterThanEqual(pos, uBlock.out_extents.xyz))) {
+    return;
   }
+
+  const ivec3 ch_pos = ivec3(0, 0, pos.z % uBlock.out_extents.w);
+
+  const vec4 in_tex = texelFetch(uInput, pos, 0);
+  const vec4 gamma_tex = texelFetch(uGamma, ch_pos, 0);
+  const vec4 beta_tex = texelFetch(uBeta, ch_pos, 0);
+  const vec4 mean_tex = texelFetch(uMean, ch_pos, 0);
+  const vec4 var_tex = texelFetch(uVar, ch_pos, 0);
+
+  const vec4 out_tex =
+      (in_tex - mean_tex) / sqrt(var_tex + uBlock.eps) * gamma_tex + beta_tex;
+
+  imageStore(uOutput, pos, out_tex);
 }
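
In math terms, each shader invocation applies the standard eval-mode batch-norm transform to one output texel (four consecutive channels, handled component-wise):

$$\mathrm{out} = \frac{\mathrm{in} - \mu}{\sqrt{\sigma^2 + \varepsilon}} \cdot \gamma + \beta$$

where $\mu$, $\sigma^2$, $\gamma$, and $\beta$ are fetched from the parameter textures at `ivec3(0, 0, pos.z % out_extents.w)`, i.e. at the channel-group index of the output position.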
diff --git a/aten/src/ATen/native/vulkan/ops/Batchnorm.cpp b/aten/src/ATen/native/vulkan/ops/Batchnorm.cpp
index 84828aa..d1fecca 100644
--- a/aten/src/ATen/native/vulkan/ops/Batchnorm.cpp
+++ b/aten/src/ATen/native/vulkan/ops/Batchnorm.cpp
@@ -1,109 +1,44 @@
-#include <ATen/native/vulkan/ops/Common.h>
+#include <ATen/Context.h>
+#include <ATen/native/vulkan/ops/Batchnorm.h>
 #include <torch/library.h>
 
 namespace at {
 namespace native {
 namespace vulkan {
 namespace ops {
-namespace {
 
-using namespace api::utils;
+namespace batchnorm {
 
-Tensor batch_norm(
-    const at::Tensor& input_arg,
-    const c10::optional<Tensor>& weight_opt /* optional */,
-    const c10::optional<Tensor>& bias_opt /* optional */,
-    const c10::optional<Tensor>& running_mean_opt /* optional */,
-    const c10::optional<Tensor>& running_var_opt /* optional */,
-    bool training,
-    double /* momentum, not used in eval mode */,
-    double eps,
-    bool /* cudnn_enable, deprecated */) {
-  TORCH_CHECK(!training, "Vulkan batchnorm only supports evaluation mode.");
-  TORCH_CHECK(
-      weight_opt && weight_opt->defined() && bias_opt && bias_opt->defined(),
-      "Vulkan batchnorm expects weight and bias arguments to be defined");
-  TORCH_CHECK(
-      running_mean_opt && running_mean_opt->defined(),
-      "running_mean must be defined in evaluation mode.");
-  TORCH_CHECK(
-      running_var_opt && running_var_opt->defined(),
-      "running_var must be defined in evaluation mode.");
-  TORCH_CHECK(input_arg.dim() == 4, "Vulkan batchnorm expects 4-dim input!");
-  TORCH_CHECK(
-      get_dim<Dim4D::Channel>(input_arg) % 4 == 0,
-      "Vulkan batchnorm expects channel dim to be multiple of 4!");
+struct Params final {
+  api::utils::ivec3 out_extents;
+  int32_t c4;
+  float eps;
+};
 
-  const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan();
-  const vTensor& v_input = convert(input);
-  const IntArrayRef v_input_sizes = v_input.sizes();
+void record_op(
+    api::Context* const context,
+    vTensor& v_output,
+    const vTensor& v_input,
+    const vTensor& v_weight,
+    const vTensor& v_bias,
+    const vTensor& v_running_mean,
+    const vTensor& v_running_var,
+    const float eps) {
+  api::PipelineBarrier pipeline_barrier{};
 
-  auto num_features = v_input.sizes()[1];
-  auto channels_ext = num_features / 4;
+  api::utils::uvec3 global_size = v_output.extents();
+  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
 
-  const Tensor weight_opt_3d = weight_opt->reshape({num_features, 1, 1});
-  const Tensor weight =
-      weight_opt_3d.is_vulkan() ? weight_opt_3d : weight_opt_3d.vulkan();
-  const vTensor& v_weight = convert(weight);
-  TORCH_CHECK(
-      weight.numel() == num_features,
-      "weight tensor should contain ",
-      num_features,
-      " elements!");
+  uint32_t num_features = get_dim<Dim4D::Channel>(v_input.sizes());
+  uint32_t channels_ext = api::utils::div_up(num_features, 4u);
 
-  const Tensor bias_opt_3d = bias_opt->reshape({num_features, 1, 1});
-  const Tensor bias =
-      bias_opt_3d.is_vulkan() ? bias_opt_3d : bias_opt_3d.vulkan();
-  const vTensor& v_bias = convert(bias);
-  TORCH_CHECK(
-      bias.numel() == num_features,
-      "bias tensor should contain ",
-      num_features,
-      " elements!");
-
-  const Tensor running_mean_opt_3d =
-      running_mean_opt->reshape({num_features, 1, 1});
-  const Tensor running_mean = running_mean_opt_3d.is_vulkan()
-      ? running_mean_opt_3d
-      : running_mean_opt_3d.vulkan();
-  const vTensor& v_running_mean = convert(running_mean);
-  TORCH_CHECK(
-      running_mean.numel() == num_features,
-      "running mean tensor should contain ",
-      num_features,
-      " elements!");
-
-  const Tensor running_var_opt_3d =
-      running_var_opt->reshape({num_features, 1, 1});
-  const Tensor running_var = running_var_opt_3d.is_vulkan()
-      ? running_var_opt_3d
-      : running_var_opt_3d.vulkan();
-  const vTensor& v_running_var = convert(running_var);
-  TORCH_CHECK(
-      running_var.numel() == num_features,
-      "running var tensor should contain ",
-      num_features,
-      " elements!");
-
-  api::Context* const context = api::context();
-
-  vTensor v_output{
-      context,
-      v_input_sizes,
-      v_input.options(),
+  Params block{
+      api::utils::make_ivec3(v_output.extents()),
+      api::utils::safe_downcast<int32_t>(channels_ext),
+      eps,
   };
 
-  const struct Block final {
-    uvec3 iextents;
-    int32_t channels_ext;
-    float epsilon;
-  } block{
-      v_output.extents(),
-      safe_downcast<int32_t>(channels_ext),
-      safe_downcast<float>(eps)};
-
   api::UniformParamsBuffer params(context, block);
-  api::PipelineBarrier pipeline_barrier{};
 
   context->submit_compute_job(
       // shader descriptor
@@ -111,9 +46,9 @@
       // pipeline barrier
       pipeline_barrier,
       // global work group size
-      v_output.extents(),
+      global_size,
       // local work group size
-      adaptive_work_group_size(v_output.extents()),
+      local_size,
       // fence handle
       VK_NULL_HANDLE,
       // shader arguments
@@ -128,8 +63,34 @@
       v_running_var.image(pipeline_barrier, api::PipelineStage::COMPUTE),
       // params buffer
       params.buffer());
+}
 
-  return convert(v_output);
+} // namespace batchnorm
+
+namespace {
+
+using namespace api::utils;
+
+Tensor batch_norm(
+    const at::Tensor& input_arg,
+    const c10::optional<Tensor>& weight_opt /* optional */,
+    const c10::optional<Tensor>& bias_opt /* optional */,
+    const c10::optional<Tensor>& running_mean_opt /* optional */,
+    const c10::optional<Tensor>& running_var_opt /* optional */,
+    bool training,
+    double /* momentum, not used in eval mode */,
+    double eps,
+    bool /* cudnn_enable, deprecated */) {
+  TORCH_CHECK(!training, "Only evaluation mode is supported!");
+  TORCH_CHECK(input_arg.dim() == 4, "Input must have dim == 4!");
+  TORCH_CHECK(
+      get_dim<Dim4D::Channel>(input_arg) % 4 == 0,
+      "Input must have channels divisible by 4!");
+
+  return run_batchnorm_context(
+      input_arg,
+      c10::make_intrusive<BatchNormPackedContext>(BatchNormPackedContext(
+          weight_opt, bias_opt, running_mean_opt, running_var_opt, eps)));
 }
 
 #ifdef USE_VULKAN_API
@@ -141,6 +102,143 @@
 #endif /* USE_VULKAN_API */
 
 } // namespace
+
+BatchNormPackedContext::BatchNormPackedContext(
+    const c10::optional<Tensor>& weight_opt,
+    const c10::optional<Tensor>& bias_opt,
+    const c10::optional<Tensor>& running_mean_opt,
+    const c10::optional<Tensor>& running_var_opt,
+    double eps)
+    : unpacked_{c10::AnyType::get()} {
+  packed_.reserve(ListArgs::kNumArgs);
+
+  // Each optional tensor arg, if provided, should be a 1-dimensional tensor.
+  // To achieve more efficient packing as a texture, each is first reshaped to
+  // {N, 1, 1}. Eventually this rearrangement should happen automatically in
+  // vTensor itself.
+
+  // Weight
+  TORCH_CHECK(weight_opt, "Weight must be provided!");
+  TORCH_CHECK(weight_opt->dim() == 1, "Weight must have ndim == 1!");
+
+  const int64_t num_features =
+      api::utils::safe_downcast<int64_t>(weight_opt->numel());
+  const Tensor weight_3d = weight_opt->reshape({num_features, 1, 1});
+  packed_.emplace_back(weight_3d.vulkan());
+
+  // Bias
+  TORCH_CHECK(bias_opt, "Bias must be provided!");
+  TORCH_CHECK(bias_opt->dim() == 1, "Bias must have ndim == 1!");
+  TORCH_CHECK(
+      bias_opt->numel() == num_features,
+      "Bias must have the same numel as weight!");
+
+  const Tensor bias_3d = bias_opt->reshape({num_features, 1, 1});
+  packed_.emplace_back(bias_3d.vulkan());
+
+  // Running Mean
+  TORCH_CHECK(running_mean_opt, "Running mean must be provided!");
+  TORCH_CHECK(running_mean_opt->dim() == 1, "Running mean must have ndim == 1");
+  TORCH_CHECK(
+      running_mean_opt->numel() == num_features,
+      "Running mean must have the same numel as weight!");
+
+  const Tensor running_mean_3d =
+      running_mean_opt->reshape({num_features, 1, 1});
+  packed_.emplace_back(running_mean_3d.vulkan());
+
+  // Running var
+  TORCH_CHECK(running_var_opt, "Running var must be provided!");
+  TORCH_CHECK(running_var_opt->dim() == 1, "Running var must have ndim == 1");
+  TORCH_CHECK(
+      running_var_opt->numel() == num_features,
+      "Running var must have the same numel as weight!");
+
+  const Tensor running_var_3d = running_var_opt->reshape({num_features, 1, 1});
+  packed_.emplace_back(running_var_3d.vulkan());
+
+  // Epsilon
+  packed_.emplace_back(eps);
+
+  if (!at::globalContext().releaseWeightsWhenPrepacking()) {
+    unpacked_.reserve(ListArgs::kNumArgs);
+    unpacked_.emplace_back(weight_opt);
+    unpacked_.emplace_back(bias_opt);
+    unpacked_.emplace_back(running_mean_opt);
+    unpacked_.emplace_back(running_var_opt);
+    unpacked_.emplace_back(eps);
+  }
+}
+
+BatchNormPackedContext BatchNormPackedContext::pack(
+    c10::impl::GenericList unpacked) {
+  return BatchNormPackedContext(
+      get_optional_tensor(unpacked, ListArgs::kWeight),
+      get_optional_tensor(unpacked, ListArgs::kBias),
+      get_optional_tensor(unpacked, ListArgs::kRunningMean),
+      get_optional_tensor(unpacked, ListArgs::kRunningVar),
+      unpacked.get(ListArgs::kEps).toDouble());
+}
+
+c10::intrusive_ptr<BatchNormPackedContext> create_batchnorm_context(
+    c10::optional<Tensor>&& weight_opt,
+    c10::optional<Tensor>&& bias_opt,
+    c10::optional<Tensor>&& running_mean_opt,
+    c10::optional<Tensor>&& running_var_opt,
+    bool training,
+    double /* momentum */,
+    double eps,
+    bool /* cudnn_enable, deprecated */) {
+  return c10::make_intrusive<BatchNormPackedContext>(BatchNormPackedContext(
+      weight_opt, bias_opt, running_mean_opt, running_var_opt, eps));
+}
+
+Tensor run_batchnorm_context(
+    const Tensor& input_arg,
+    const c10::intrusive_ptr<BatchNormPackedContext>& batchnorm_context) {
+  api::Context* const context = api::context();
+
+  const vTensor& v_input = convert(input_arg);
+
+  const vTensor& v_weight = convert(
+      batchnorm_context->get_val(BatchNormPackedContext::ListArgs::kWeight)
+          .toTensor());
+
+  const vTensor& v_bias = convert(
+      batchnorm_context->get_val(BatchNormPackedContext::ListArgs::kBias)
+          .toTensor());
+
+  const vTensor& v_running_mean = convert(
+      batchnorm_context->get_val(BatchNormPackedContext::ListArgs::kRunningMean)
+          .toTensor());
+
+  const vTensor& v_running_var = convert(
+      batchnorm_context->get_val(BatchNormPackedContext::ListArgs::kRunningVar)
+          .toTensor());
+
+  const float eps = api::utils::safe_downcast<float>(
+      batchnorm_context->get_val(BatchNormPackedContext::ListArgs::kEps)
+          .toDouble());
+
+  vTensor v_output{
+      context,
+      v_input.sizes(),
+      v_input.options(),
+  };
+
+  batchnorm::record_op(
+      context,
+      v_output,
+      v_input,
+      v_weight,
+      v_bias,
+      v_running_mean,
+      v_running_var,
+      eps);
+
+  return convert(v_output);
+}
+
 } // namespace ops
 } // namespace vulkan
 } // namespace native
diff --git a/aten/src/ATen/native/vulkan/ops/Batchnorm.h b/aten/src/ATen/native/vulkan/ops/Batchnorm.h
new file mode 100644
index 0000000..6afaeb6
--- /dev/null
+++ b/aten/src/ATen/native/vulkan/ops/Batchnorm.h
@@ -0,0 +1,68 @@
+#pragma once
+
+#ifdef USE_VULKAN_API
+
+#include <ATen/native/vulkan/ops/Common.h>
+#include <ATen/native/vulkan/ops/VulkanPackedContext.h>
+#include <torch/library.h>
+
+namespace at {
+namespace native {
+namespace vulkan {
+namespace ops {
+
+class BatchNormPackedContext final : virtual public VulkanPackedContext,
+                                     public torch::jit::CustomClassHolder {
+ private:
+  c10::impl::GenericList unpacked_;
+
+ public:
+  BatchNormPackedContext(
+      const c10::optional<Tensor>& weight_opt,
+      const c10::optional<Tensor>& bias_opt,
+      const c10::optional<Tensor>& running_mean_opt,
+      const c10::optional<Tensor>& running_var_opt,
+      double eps);
+
+  /*
+   * Assigns a name to each index in the packed/unpacked list.
+   */
+  struct ListArgs final {
+    static constexpr uint32_t kWeight = 0u;
+    static constexpr uint32_t kBias = 1u;
+    static constexpr uint32_t kRunningMean = 2u;
+    static constexpr uint32_t kRunningVar = 3u;
+    static constexpr uint32_t kEps = 4u;
+
+    static constexpr uint32_t kNumArgs = 5u;
+  };
+
+  static BatchNormPackedContext pack(c10::impl::GenericList);
+
+  const c10::impl::GenericList unpack() const override {
+    TORCH_CHECK(unpacked_.size() > 0u, "unpacked_ does not have any elements!");
+
+    return unpacked_;
+  }
+};
+
+c10::intrusive_ptr<BatchNormPackedContext> create_batchnorm_context(
+    c10::optional<Tensor>&& weight_opt,
+    c10::optional<Tensor>&& bias_opt,
+    c10::optional<Tensor>&& running_mean_opt,
+    c10::optional<Tensor>&& running_var_opt,
+    bool training,
+    double /* momentum */,
+    double eps,
+    bool /* cudnn_enable, deprecated */);
+
+Tensor run_batchnorm_context(
+    const Tensor& input_arg,
+    const c10::intrusive_ptr<BatchNormPackedContext>& context);
+
+} // namespace ops
+} // namespace vulkan
+} // namespace native
+} // namespace at
+
+#endif /* USE_VULKAN_API */
diff --git a/aten/src/ATen/native/vulkan/ops/Register.cpp b/aten/src/ATen/native/vulkan/ops/Register.cpp
index 18d5a6f..25f0a6d 100644
--- a/aten/src/ATen/native/vulkan/ops/Register.cpp
+++ b/aten/src/ATen/native/vulkan/ops/Register.cpp
@@ -1,5 +1,6 @@
 #ifdef USE_VULKAN_API
 
+#include <ATen/native/vulkan/ops/Batchnorm.h>
 #include <ATen/native/vulkan/ops/Common.h>
 #include <ATen/native/vulkan/ops/Convolution.h>
 #include <ATen/native/vulkan/ops/Gru.h>
@@ -16,6 +17,19 @@
 namespace {
 
 TORCH_LIBRARY(vulkan, m) {
+  m.class_<BatchNormPackedContext>("BatchNormPackedContext")
+      .def_pickle(
+          // __getstate__
+          [](const c10::intrusive_ptr<BatchNormPackedContext>& context) {
+            // context is packed
+            return context->unpack();
+          },
+          // __setstate__
+          [](c10::impl::GenericList state) {
+            // state is unpacked
+            return c10::make_intrusive<BatchNormPackedContext>(
+                BatchNormPackedContext::pack(state));
+          });
   m.class_<LinearPackedContext>("LinearPackedContext")
       .def_pickle(
           // __getstate__
@@ -147,6 +161,22 @@
       "Tensor hx_vk, "
       "Tensor cx_vk, "
       "__torch__.torch.classes.vulkan.LstmPackedContext L_prepack) -> (Tensor next_input, Tensor hidden_state, Tensor cell_state)"));
+  m.def(TORCH_SELECTIVE_SCHEMA(
+      "vulkan_prepack::create_batchnorm_context("
+      "Tensor? weight_opt, "
+      "Tensor? bias_opt, "
+      "Tensor? running_mean_opt, "
+      "Tensor? running_var_opt, "
+      "bool training, "
+      "float momentum, "
+      "float eps, "
+      "bool cudnn_enable) "
+      "-> __torch__.torch.classes.vulkan.BatchNormPackedContext"));
+  m.def(TORCH_SELECTIVE_SCHEMA(
+      "vulkan_prepack::run_batchnorm_context("
+      "Tensor input_vk, "
+      "__torch__.torch.classes.vulkan.BatchNormPackedContext context) "
+      "-> Tensor out"));
 }
 
 TORCH_LIBRARY_IMPL(vulkan_prepack, CPU, m) {
@@ -168,6 +198,9 @@
   m.impl(
       TORCH_SELECTIVE_NAME("vulkan_prepack::create_lstm_context"),
       TORCH_FN(create_lstm_context));
+  m.impl(
+      TORCH_SELECTIVE_NAME("vulkan_prepack::create_batchnorm_context"),
+      TORCH_FN(create_batchnorm_context));
 }
 
 TORCH_LIBRARY_IMPL(vulkan_prepack, Vulkan, m) {
@@ -189,6 +222,9 @@
   m.impl(
       TORCH_SELECTIVE_NAME("vulkan_prepack::run_lstm_context"),
       TORCH_FN(run_lstm_context));
+  m.impl(
+      TORCH_SELECTIVE_NAME("vulkan_prepack::run_batchnorm_context"),
+      TORCH_FN(run_batchnorm_context));
 }
 
 } // namespace
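
For reference, a minimal sketch of the serialization round trip that the `def_pickle` registration above wires up, assuming `context` is a packed `BatchNormPackedContext` (e.g. from the earlier sketch). Note that `unpack()` throws when the model was prepacked with `releaseWeightsWhenPrepacking()` enabled, since `unpacked_` is then left empty:

```cpp
// __getstate__: serialize the original (unpacked) argument list.
c10::impl::GenericList state = context->unpack();

// __setstate__: re-pack the arguments into a fresh context.
auto restored = c10::make_intrusive<BatchNormPackedContext>(
    BatchNormPackedContext::pack(state));
```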
diff --git a/aten/src/ATen/test/vulkan_api_test.cpp b/aten/src/ATen/test/vulkan_api_test.cpp
index 2519267..a9dc190 100644
--- a/aten/src/ATen/test/vulkan_api_test.cpp
+++ b/aten/src/ATen/test/vulkan_api_test.cpp
@@ -630,7 +630,7 @@
       at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
       at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
       at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
-      true,
+      false,
       0.1,
       1e-05,
       false);
@@ -644,7 +644,7 @@
       at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
       at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
       at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
-      true,
+      false,
       0.1,
       1e-05,
       false);
@@ -658,7 +658,7 @@
       at::rand({7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
       at::rand({7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
       at::rand({7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
-      true,
+      false,
       0.1,
       1e-05,
       false);
@@ -672,7 +672,7 @@
       at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
       at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
       at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
-      true,
+      false,
       0.1,
       1e-05,
       false);
@@ -686,7 +686,7 @@
       at::rand({12}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
       at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
       at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
-      true,
+      false,
       0.1,
       1e-05,
       false);
@@ -700,7 +700,7 @@
       at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
       at::rand({12}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
       at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
-      true,
+      false,
       0.1,
       1e-05,
       false);
@@ -714,7 +714,7 @@
       at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
       at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
       at::rand({12}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(),
-      true,
+      false,
       0.1,
       1e-05,
       false);
diff --git a/torch/csrc/jit/passes/vulkan_rewrite.cpp b/torch/csrc/jit/passes/vulkan_rewrite.cpp
index cbfa612..0c37d5b 100644
--- a/torch/csrc/jit/passes/vulkan_rewrite.cpp
+++ b/torch/csrc/jit/passes/vulkan_rewrite.cpp
@@ -19,6 +19,24 @@
 
 namespace {
 
+void insertPrePackedBatchNormOp(std::shared_ptr<Graph>& graph) {
+  std::string batchnorm_pattern = R"(
+    graph(%input, %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable):
+        %r = aten::batch_norm(%input, %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable)
+        return (%r))";
+  std::string prepacked_ops_pattern = R"(
+    graph(%input, %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable):
+        %op_context : __torch__.torch.classes.vulkan.BatchNormPackedContext = vulkan_prepack::create_batchnorm_context(
+            %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable)
+        %res = vulkan_prepack::run_batchnorm_context(%input, %op_context)
+        return (%res))";
+
+  SubgraphRewriter batchnorm_rewriter;
+  batchnorm_rewriter.RegisterRewritePattern(
+      batchnorm_pattern, prepacked_ops_pattern);
+  batchnorm_rewriter.runOnGraph(graph);
+}
+
 void insertPrePackedLinearOp(std::shared_ptr<Graph>& graph) {
   // fuse decomposed linear into aten::linear
   FuseLinear(graph);
@@ -265,6 +283,7 @@
   insertPrePackedConv2dOp(graph);
   insertPrePackedGruOp(graph);
   insertPrePackedLstmOp(graph);
+  insertPrePackedBatchNormOp(graph);
 }
 
 void vulkanInsertPrePackedOps(script::Module& module) {
@@ -295,7 +314,9 @@
         (n->kind() ==
          Symbol::fromQualString("vulkan_prepack::create_gru_context")) ||
         (n->kind() ==
-         Symbol::fromQualString("vulkan_prepack::create_lstm_context")));
+         Symbol::fromQualString("vulkan_prepack::create_lstm_context")) ||
+        (n->kind() ==
+         Symbol::fromQualString("vulkan_prepack::create_batchnorm_context")));
   };
   PrePackingOpsFolder(m, filter_fn, "prepack_folding");
 }
@@ -320,18 +341,20 @@
   auto cloned_module = m.clone();
   cloned_module.eval();
   cloned_module = FoldConvBatchNorm(cloned_module);
-  vulkanInsertPrePackedOps(cloned_module);
   cloned_module = freeze_module(cloned_module, preserved_methods);
+  vulkanInsertPrePackedOps(cloned_module);
+  vulkanFusePrePackedConvWithClamp(cloned_module);
+  vulkanFoldPrePackingOps(cloned_module);
+  removeDropout(cloned_module);
+  vulkanRemoveMutation(cloned_module);
+
   if (!optimization_blocklist.count(
           MobileOptimizerType::VULKAN_AUTOMATIC_GPU_TRANSFER)) {
     transferInputOutputBackends(cloned_module);
     cloned_module.register_attribute(
         "requires_backend_transfers", BoolType::get(), false);
   }
-  vulkanFusePrePackedConvWithClamp(cloned_module);
-  vulkanFoldPrePackingOps(cloned_module);
-  removeDropout(cloned_module);
-  vulkanRemoveMutation(cloned_module);
+
   // remove duplicated constants
   vulkanRunCanonicalOptimizations(cloned_module);
   eliminateDeadCode(cloned_module);