[vulkan] jit passes for vulkan conv2d prepack and fuse with clamp (#39282)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/39282
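
The new JIT passes rewrite aten::conv2d into vulkan_prepack::conv2d_clamp_prepack / vulkan_prepack::conv2d_clamp_run, fuse a following relu/hardtanh into the prepacked op's output_min/output_max arguments, and fold the now-constant prepack calls into module attributes. A minimal sketch of driving the passes from Python, mirroring test/test_vulkan.py below (assumes a USE_VULKAN=1 build; `model` stands for any nn.Module that uses conv2d, optionally followed by relu/hardtanh):

    import torch

    scripted = torch.jit.script(model)
    scripted.eval()
    # Rewrite aten::conv2d into the vulkan_prepack prepack/run op pair.
    torch._C._jit_pass_vulkan_insert_prepacked_ops(scripted._c)
    # Freeze before fusing clamps and folding prepack calls.
    scripted._c = torch._C._freeze_module(scripted._c)
    torch._C._jit_pass_vulkan_fuse_clamp_w_prepacked_conv(scripted._c)
    torch._C._jit_pass_vulkan_fold_prepacking_ops(scripted._c)

The same pipeline runs end to end as torch::jit::vulkanOptimizeForMobile, exposed both as torch._C._jit_pass_vulkan_optimize_for_mobile and through the new --vulkan flag of the optimize_for_mobile binary.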

Test Plan: Imported from OSS

Differential Revision: D21962424

Pulled By: IvanKobzarev

fbshipit-source-id: 2d20e827d2c3836b7e6b443293377c68dc1ffa5a
diff --git a/aten/src/ATen/native/vulkan/VulkanAten.cpp b/aten/src/ATen/native/vulkan/VulkanAten.cpp
index bd71f82..a209e71 100644
--- a/aten/src/ATen/native/vulkan/VulkanAten.cpp
+++ b/aten/src/ATen/native/vulkan/VulkanAten.cpp
@@ -179,7 +179,7 @@
       voutput,
       vinput,
       weight.data_ptr<float>(),
-      bias.defined() ? c10::make_optional<float*>(bias.data_ptr<float>())
+      bias.defined() ? c10::make_optional<const float*>(bias.data_ptr<float>())
                      : c10::nullopt,
       params);
   return new_with_vtensor_vulkan(std::move(voutput), input.options());
@@ -242,7 +242,8 @@
         voutput,
         vinput,
         vweight,
-        hasBias ? c10::make_optional((*bias).data_ptr<float>()) : c10::nullopt,
+        hasBias ? c10::make_optional<const float*>((*bias).data_ptr<float>())
+                : c10::nullopt,
         params,
         output_min,
         output_max);
diff --git a/aten/src/ATen/native/vulkan/VulkanConvolution.cpp b/aten/src/ATen/native/vulkan/VulkanConvolution.cpp
index a3128ff..ef8c18a 100644
--- a/aten/src/ATen/native/vulkan/VulkanConvolution.cpp
+++ b/aten/src/ATen/native/vulkan/VulkanConvolution.cpp
@@ -66,14 +66,13 @@
   const auto stride_expanded = expand_param_if_needed(stride, "stride", 2);
   const auto dilation_expanded =
       expand_param_if_needed(dilation, "dilation", 2);
-  const Tensor weight_nchw = weight.contiguous();
+  Tensor weight_nchw = weight.contiguous();
+  auto ws = weight_nchw.sizes();
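+  // groups > 1 (depthwise): upload the NCHW weight as-is; the prepacked
+  // weight layout is only used for the groups == 1 path.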
   return ContextConv2D{
-      at::native::vulkan_convolution_prepack_weights(weight),
+      groups == 1 ? at::native::vulkan_convolution_prepack_weights(weight_nchw)
+                  : weight_nchw.vulkan(),
       bias.has_value() ? c10::make_optional((*bias).vulkan()) : c10::nullopt,
-      {weight_nchw.sizes()[0],
-       weight_nchw.sizes()[1],
-       weight_nchw.sizes()[2],
-       weight_nchw.sizes()[3]},
+      {{ws[0], ws[1], ws[2], ws[3]}},
       {padding_expanded[0], padding_expanded[1]},
       {stride_expanded[0], stride_expanded[1]},
       {dilation_expanded[0], dilation_expanded[1]},
diff --git a/aten/src/ATen/native/vulkan/VulkanOps.cpp b/aten/src/ATen/native/vulkan/VulkanOps.cpp
index 2dda64c..c69a70b 100644
--- a/aten/src/ATen/native/vulkan/VulkanOps.cpp
+++ b/aten/src/ATen/native/vulkan/VulkanOps.cpp
@@ -176,7 +176,7 @@
 }
 
 VBuffer bufferFromOptionalHostData(
-    c10::optional<float*> data,
+    c10::optional<const float*> data,
     const uint32_t size) {
   const auto sizeAligned =
       ROUND_UP(size, context().limits().minStorageBufferOffsetAlignment);
@@ -202,17 +202,15 @@
 void conv2d_depthwise(
     VulkanTensor& output,
     const VulkanTensor& input,
-    const float* weight,
-    const c10::optional<float*> bias,
-    const Conv2DParams params,
+    const VulkanTensor& weight,
+    const VBuffer& biasBuffer,
+    const Conv2DParams& params,
     c10::optional<float> output_min,
     c10::optional<float> output_max) {
   TORCH_INTERNAL_ASSERT(params.G == params.C);
   auto osizes = output.sizes();
   TORCH_INTERNAL_ASSERT(osizes[2] == params.OH);
   TORCH_INTERNAL_ASSERT(osizes[3] == params.OW);
-  auto biasBuffer =
-      bufferFromOptionalHostData(bias, conv2d_biasBufferSize(params.OC));
   struct ConstBlock {
     int32_t padding[2];
     int32_t kernelSize[2];
@@ -234,9 +232,6 @@
       output_max ? *output_max : std::numeric_limits<float>::infinity()};
   VBuffer constBuffer = makeUniformConstBuffer((void*)&cb, sizeof(cb));
 
-  VulkanTensor kernel{{params.OC, params.KH, params.KW}};
-  kernel.set_data_from_host(weight);
-
   VkDescriptorSetLayout descriptorSetLayout{};
   VkDescriptorPool descriptorPool{};
   VkDescriptorSet descriptorSet{};
@@ -256,7 +251,7 @@
 
   output.image()->bindStorageImage(descriptorSet, 0);
   input.image()->bindShaderRead(descriptorSet, 1);
-  kernel.image()->bindShaderRead(descriptorSet, 2);
+  weight.image()->bindShaderRead(descriptorSet, 2);
   biasBuffer.bind(descriptorSet, 3);
   constBuffer.bind(descriptorSet, 4);
 
@@ -269,7 +264,7 @@
   auto commandBuffer = computeUnit.commandBuffer();
   output.image()->addImageMemoryBarrierToGeneral(commandBuffer);
   input.image()->addImageMemoryBarrierToShaderRead(commandBuffer);
-  kernel.image()->addImageMemoryBarrierToShaderRead(commandBuffer);
+  weight.image()->addImageMemoryBarrierToShaderRead(commandBuffer);
   computeUnit.dispatchCommandBuffer(
       params.OW, params.OH, params.OC_4, workGroupSize);
   computeUnit.endCommandBuffer();
@@ -279,6 +274,44 @@
   vkDestroyDescriptorSetLayout(device, descriptorSetLayout, nullptr);
 }
 
+void conv2d_depthwise(
+    VulkanTensor& output,
+    const VulkanTensor& input,
+    const VulkanTensor& weight,
+    const c10::optional<const float*> bias,
+    const Conv2DParams params,
+    c10::optional<float> output_min,
+    c10::optional<float> output_max) {
+  conv2d_depthwise(
+      output,
+      input,
+      weight,
+      bufferFromOptionalHostData(bias, conv2d_biasBufferSize(params.OC)),
+      params,
+      output_min,
+      output_max);
+}
+
+void conv2d_depthwise(
+    VulkanTensor& output,
+    const VulkanTensor& input,
+    const float* weight,
+    const c10::optional<const float*> bias,
+    const Conv2DParams params,
+    c10::optional<float> output_min,
+    c10::optional<float> output_max) {
+  VulkanTensor weightTensor{{params.OC, params.KH, params.KW}};
+  weightTensor.set_data_from_host(weight);
+  conv2d_depthwise(
+      output,
+      input,
+      weightTensor,
+      bufferFromOptionalHostData(bias, conv2d_biasBufferSize(params.OC)),
+      params,
+      output_min,
+      output_max);
+}
+
 ImageSizes conv2d_prepack_weights_image_sizes(
     int64_t OC,
     int64_t C,
@@ -463,7 +496,7 @@
     VulkanTensor& output,
     const VulkanTensor& input,
     const VImage& kernelImage,
-    const c10::optional<float*> bias,
+    const c10::optional<const float*> bias,
     const Conv2DParams& params,
     c10::optional<float> output_min,
     c10::optional<float> output_max) {
@@ -483,10 +516,22 @@
     VulkanTensor& output,
     const VulkanTensor& input,
     const VulkanTensor& weight_prepacked,
-    c10::optional<float*> bias,
+    c10::optional<const float*> bias,
     const Conv2DParams params,
     c10::optional<float> output_min,
     c10::optional<float> output_max) {
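+  // Grouped convolutions (G > 1) take the depthwise path;
+  // conv2d_depthwise asserts that G == C.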
+  if (params.G > 1) {
+    conv2d_depthwise(
+        output,
+        input,
+        weight_prepacked,
+        bufferFromOptionalHostData(bias, conv2d_biasBufferSize(params.OC)),
+        params,
+        output_min,
+        output_max);
+    return;
+  }
+
   conv2d(
       output,
       input,
@@ -505,6 +550,18 @@
     const Conv2DParams params,
     c10::optional<float> output_min,
     c10::optional<float> output_max) {
+  if (params.G > 1) {
+    conv2d_depthwise(
+        output,
+        input,
+        weight_prepacked,
+        *(bias.buffer()),
+        params,
+        output_min,
+        output_max);
+    return;
+  }
+
   conv2d(
       output,
       input,
@@ -519,7 +576,7 @@
     VulkanTensor& output,
     const VulkanTensor& input,
     const float* weight,
-    const c10::optional<float*> bias,
+    const c10::optional<const float*> bias,
     const Conv2DParams params,
     c10::optional<float> output_min,
     c10::optional<float> output_max) {
diff --git a/aten/src/ATen/native/vulkan/VulkanOps.h b/aten/src/ATen/native/vulkan/VulkanOps.h
index 969bd85..64b32ea 100644
--- a/aten/src/ATen/native/vulkan/VulkanOps.h
+++ b/aten/src/ATen/native/vulkan/VulkanOps.h
@@ -37,7 +37,7 @@
     VulkanTensor& output,
     const VulkanTensor& input,
     const float* weight,
-    const c10::optional<float*> bias,
+    const c10::optional<const float*> bias,
     const Conv2DParams params,
     c10::optional<float> output_min = c10::nullopt,
     c10::optional<float> output_max = c10::nullopt);
@@ -46,7 +46,7 @@
     VulkanTensor& output,
     const VulkanTensor& input,
     const VulkanTensor& weight_prepacked,
-    const c10::optional<float*> bias,
+    const c10::optional<const float*> bias,
     const Conv2DParams params,
     c10::optional<float> output_min = c10::nullopt,
     c10::optional<float> output_max = c10::nullopt);
diff --git a/aten/src/ATen/test/vulkan_test.cpp b/aten/src/ATen/test/vulkan_test.cpp
index da6ece7..9a6ed36 100644
--- a/aten/src/ATen/test/vulkan_test.cpp
+++ b/aten/src/ATen/test/vulkan_test.cpp
@@ -496,7 +496,7 @@
   ASSERT_TRUE(no_prepack_check);
 
   auto prepack = callOpByName(
-      "vulkan::conv2d_clamp_prepack",
+      "vulkan_prepack::conv2d_clamp_prepack",
       "",
       t_w,
       t_b,
@@ -507,7 +507,7 @@
       output_min,
       output_max);
   auto tv_out_prepack_ivalues =
-      callOpByName("vulkan::conv2d_clamp_run", "", tv_in, prepack[0]);
+      callOpByName("vulkan_prepack::conv2d_clamp_run", "", tv_in, prepack[0]);
   auto tv_out_prepack = tv_out_prepack_ivalues[0].toTensor();
   auto t_out_prepack = tv_out_prepack.cpu();
   const auto prepack_check = almostEqual(t_out_prepack, t_out_expected);
diff --git a/binaries/CMakeLists.txt b/binaries/CMakeLists.txt
index b7a7f95..075bc05 100644
--- a/binaries/CMakeLists.txt
+++ b/binaries/CMakeLists.txt
@@ -103,3 +103,4 @@
 caffe2_binary_target("tutorial_blob.cc")
 
 caffe2_binary_target("dump_operator_names.cc")
+caffe2_binary_target("optimize_for_mobile.cc")
diff --git a/binaries/optimize_for_mobile.cc b/binaries/optimize_for_mobile.cc
index 4d91448..94293ba 100644
--- a/binaries/optimize_for_mobile.cc
+++ b/binaries/optimize_for_mobile.cc
@@ -17,6 +17,7 @@
 #include <string>
 
 #include "torch/csrc/jit/api/module.h"
+#include "torch/csrc/jit/passes/vulkan_rewrite.h"
 #include "torch/csrc/jit/passes/xnnpack_rewrite.h"
 #include "torch/csrc/jit/serialization/import.h"
 
@@ -29,6 +30,7 @@
     save_for_mobile,
     false,
     "Save the model with bytecode format compatible with lite inteprter.");
+C10_DEFINE_bool(vulkan, false, "Apply Vulkan-specific optimize_for_mobile passes.");
 
 int main(int argc, char** argv) {
   c10::SetUsageMessage(
@@ -52,7 +54,10 @@
   }
 
   auto module = torch::jit::load(FLAGS_model);
-  auto optimized_module = torch::jit::optimizeForMobile(module);
+
+  auto optimized_module = FLAGS_vulkan
+      ? torch::jit::vulkanOptimizeForMobile(module)
+      : torch::jit::optimizeForMobile(module);
 
   if (FLAGS_save_for_mobile) {
     optimized_module._save_for_mobile(output_model_name);
diff --git a/test/run_test.py b/test/run_test.py
index 36d4852..268c92b 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -48,6 +48,7 @@
     'test_optim',
     'test_mobile_optimizer',
     'test_xnnpack_integration',
+    'test_vulkan',
     'test_quantization',
     'test_sparse',
     'test_serialization',
diff --git a/test/test_vulkan.py b/test/test_vulkan.py
new file mode 100644
index 0000000..602feca
--- /dev/null
+++ b/test/test_vulkan.py
@@ -0,0 +1,162 @@
+import unittest
+import torch
+from torch.nn import functional as F
+
+from torch.testing._internal.common_utils import TestCase, run_tests
+from torch.testing import FileCheck
+import io
+
+@unittest.skipUnless(torch.is_vulkan_available(),
+                     "Vulkan backend must be available for these tests.")
+class TestVulkanRewritePass(TestCase):
+    @staticmethod
+    def validate_transformed_module(
+            # First arg is named "self" to appease flake8; it is actually the module under test.
+            self,
+            pattern_count_map,
+            data_shape,
+            prepack_removal=False,
+            fuse_clamping_ops=False):
+        module_instance = self
+        scripted_model = torch.jit.script(module_instance)
+        scripted_model.eval()
+        input_data = torch.normal(1, 20, size=data_shape)
+        ref_result = scripted_model(input_data)
+        torch._C._jit_pass_vulkan_insert_prepacked_ops(scripted_model._c)
+        if fuse_clamping_ops or prepack_removal:
+            scripted_model._c = torch._C._freeze_module(scripted_model._c)
+        if fuse_clamping_ops:
+            torch._C._jit_pass_vulkan_fuse_clamp_w_prepacked_conv(scripted_model._c)
+        if prepack_removal:
+            torch._C._jit_pass_vulkan_fold_prepacking_ops(scripted_model._c)
+
+        buffer = io.BytesIO()
+        torch.jit.save(scripted_model, buffer)
+        buffer.seek(0)
+        deserialized_scripted_model = torch.jit.load(buffer)
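+        # v == 0: pattern must occur at least once; v == -1: pattern must be
+        # absent; any other v: pattern must occur exactly v times.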
+        for pattern, v in pattern_count_map.items():
+            if v == 0:
+                FileCheck().check(pattern).run(deserialized_scripted_model.graph)
+            elif v == -1:
+                FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
+            else:
+                FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)
+
+    def test_conv(self):
+        # Conv params
+        batch_size = 2
+        input_channels_per_group = 6
+        height = 16
+        width = 16
+        output_channels_per_group = 6
+        groups = 4
+        kernel_h = kernel_w = 3
+        stride_h = stride_w = 1
+        pad_h = pad_w = 1
+        dilation = 1
+        input_channels = input_channels_per_group * groups
+        output_channels = output_channels_per_group * groups
+        kernels = (kernel_h, kernel_w)
+        strides = (stride_h, stride_w)
+        paddings = (pad_h, pad_w)
+        dilations = (dilation, dilation)
+        conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
+        conv_bias_shape = (output_channels,)
+
+        class Conv2D(torch.nn.Module):
+            def __init__(self):
+                super(Conv2D, self).__init__()
+                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
+                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
+                self.strides = strides
+                self.paddings = paddings
+                self.dilations = dilations
+                self.groups = groups
+
+            def forward(self, x):
+                return F.conv2d(x, self.weight, self.bias,
+                                self.strides, self.paddings, self.dilations, self.groups)
+
+        data_shape = (batch_size, input_channels, height, width)
+        pattern_count_map = {"Tensor = aten::conv2d": -1,
+                             "vulkan_prepack::conv2d_clamp_prepack": 1,
+                             "vulkan_prepack::conv2d_clamp_run": 1}
+        TestVulkanRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape)
+
+        class Conv2DRelu(torch.nn.Module):
+            def __init__(self):
+                super(Conv2DRelu, self).__init__()
+                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
+                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
+                self.strides = strides
+                self.paddings = paddings
+                self.dilations = dilations
+                self.groups = groups
+
+            def forward(self, x):
+                o = F.conv2d(x, self.weight, self.bias,
+                             self.strides, self.paddings, self.dilations, self.groups)
+                o = F.relu(o)
+                return o
+
+        data_shape = (batch_size, input_channels, height, width)
+        pattern_count_map = {"Tensor = aten::conv2d": -1,
+                             "vulkan_prepack::conv2d_clamp_prepack": 1,
+                             "vulkan_prepack::conv2d_clamp_run": 1}
+        TestVulkanRewritePass.validate_transformed_module(
+            Conv2DRelu(), pattern_count_map, data_shape)
+
+        pattern_count_map["aten::relu"] = 1
+        pattern_count_map["vulkan_prepack::conv2d_clamp_prepack"] = -1
+        TestVulkanRewritePass.validate_transformed_module(
+            Conv2DRelu(),
+            pattern_count_map,
+            data_shape,
+            prepack_removal=True)
+        pattern_count_map["aten::relu"] = -1
+        TestVulkanRewritePass.validate_transformed_module(
+            Conv2DRelu(),
+            pattern_count_map,
+            data_shape,
+            prepack_removal=True,
+            fuse_clamping_ops=True)
+
+        class Conv2DHardtanh(torch.nn.Module):
+            def __init__(self):
+                super(Conv2DHardtanh, self).__init__()
+                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
+                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
+                self.strides = strides
+                self.paddings = paddings
+                self.dilations = dilations
+                self.groups = groups
+
+            def forward(self, x):
+                o = F.conv2d(x, self.weight, self.bias,
+                             self.strides, self.paddings, self.dilations, self.groups)
+                o = F.hardtanh(o)
+                return o
+
+        data_shape = (batch_size, input_channels, height, width)
+        pattern_count_map = {"Tensor = aten::conv2d": -1,
+                             "vulkan_prepack::conv2d_clamp_prepack": 1,
+                             "vulkan_prepack::conv2d_clamp_run": 1}
+        TestVulkanRewritePass.validate_transformed_module(Conv2DHardtanh(), pattern_count_map, data_shape)
+        pattern_count_map["aten::hardtanh"] = 1
+        pattern_count_map["vulkan_prepack::conv2d_clamp_prepack"] = -1
+        TestVulkanRewritePass.validate_transformed_module(
+            Conv2DHardtanh(),
+            pattern_count_map,
+            data_shape,
+            prepack_removal=True)
+        pattern_count_map["aten::hardtanh"] = -1
+        TestVulkanRewritePass.validate_transformed_module(
+            Conv2DHardtanh(),
+            pattern_count_map,
+            data_shape,
+            prepack_removal=True,
+            fuse_clamping_ops=True)
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index 31e44de..3ab3116 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -177,6 +177,7 @@
     "torch/csrc/jit/passes/utils/memory_dag.cpp",
     "torch/csrc/jit/passes/utils/subgraph_utils.cpp",
     "torch/csrc/jit/passes/xnnpack_rewrite.cpp",
+    "torch/csrc/jit/passes/vulkan_rewrite.cpp",
     "torch/csrc/jit/passes/quantization/helper.cpp",
     "torch/csrc/jit/passes/quantization/quantization_type.cpp",
     "torch/csrc/jit/passes/quantization/insert_observers.cpp",
diff --git a/torch/csrc/jit/passes/graph_rewrite_helper.cpp b/torch/csrc/jit/passes/graph_rewrite_helper.cpp
index a52b09e..4c53d57 100644
--- a/torch/csrc/jit/passes/graph_rewrite_helper.cpp
+++ b/torch/csrc/jit/passes/graph_rewrite_helper.cpp
@@ -152,6 +152,41 @@
   rewriter_conv3d.runOnGraph(graph, filter_conv3d);
 }
 
+bool isClampFusable(
+    const Match& match,
+    const std::unordered_map<std::string, Value*>& vmap) {
+  const auto& match_vmap = match.values_map;
+  TORCH_CHECK(
+      vmap.find("dummy_min_max") != vmap.end(),
+      "Expected to find dummy_min_max Value in the subgraph to be replaced.");
+  auto dummy_min_max =
+      graph_rewrite_helper::getIValue("dummy_min_max", match_vmap, vmap);
+
+  auto is_fusable = !dummy_min_max || dummy_min_max.value().isNone();
+
+  // Also check that the output_min and output_max values are actually
+  // constants. If hardtanh's min/max Values are not constants, they would be
+  // rerouted to the prepack op, and we would then be unable to remove the
+  // prepacking ops.
+  if (vmap.find("output_min") != vmap.end()) {
+    // The aten::relu pattern does not have output_min/output_max;
+    // aten::hardtanh/_ does.
+    TORCH_CHECK(
+        vmap.find("output_max") != vmap.end(),
+        "Expected to find output_max as well, given that "
+        "output_min exists in the pattern graph.");
+    // If output_min/max are not constant, we get c10::nullopt.
+    auto output_min =
+        graph_rewrite_helper::getIValue("output_min", match_vmap, vmap);
+    auto output_max =
+        graph_rewrite_helper::getIValue("output_max", match_vmap, vmap);
+    is_fusable =
+        is_fusable && (output_min.has_value() && output_max.has_value());
+  }
+
+  return is_fusable;
+}
+
 } // namespace graph_rewrite_helper
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/passes/graph_rewrite_helper.h b/torch/csrc/jit/passes/graph_rewrite_helper.h
index 5d64228..00ce273 100644
--- a/torch/csrc/jit/passes/graph_rewrite_helper.h
+++ b/torch/csrc/jit/passes/graph_rewrite_helper.h
@@ -19,6 +19,10 @@
     const std::unordered_map<std::string, Value*>& vmap);
 void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph);
 
+bool isClampFusable(
+    const Match& match,
+    const std::unordered_map<std::string, Value*>& vmap);
+
 using MatchFilter = std::function<
     bool(const Match&, const std::unordered_map<std::string, Value*>&)>;
 
diff --git a/torch/csrc/jit/passes/vulkan_rewrite.cpp b/torch/csrc/jit/passes/vulkan_rewrite.cpp
new file mode 100644
index 0000000..6d0dc6c
--- /dev/null
+++ b/torch/csrc/jit/passes/vulkan_rewrite.cpp
@@ -0,0 +1,205 @@
+#include <ATen/core/jit_type.h>
+#ifdef USE_VULKAN
+#include <ATen/native/vulkan/VulkanOpContext.h>
+#endif
+
+#include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/ir/subgraph_matcher.h>
+#include <torch/csrc/jit/passes/constant_pooling.h>
+#include <torch/csrc/jit/passes/fold_conv_bn.h>
+#include <torch/csrc/jit/passes/freeze_module.h>
+#include <torch/csrc/jit/passes/fuse_linear.h>
+#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
+#include <torch/csrc/jit/passes/prepack_folding.h>
+#include <torch/csrc/jit/passes/remove_dropout.h>
+#include <torch/csrc/jit/passes/subgraph_rewrite.h>
+#include <torch/csrc/jit/passes/vulkan_rewrite.h>
+
+namespace torch {
+namespace jit {
+
+#ifdef USE_VULKAN
+
+namespace {
+
+void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
+  graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);
+
+  std::string conv_2d_pattern = R"(
+    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
+        %r = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
+        return (%r) )";
+
+  std::string prepacked_ops_conv2d_pattern = R"(
+    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
+        %output_min_max : None = prim::Constant()
+        %packed_weight_bias = vulkan_prepack::conv2d_clamp_prepack(
+            %weight, %bias, %stride, %padding, %dilation, %groups,
+            %output_min_max, %output_min_max)
+        %r = vulkan_prepack::conv2d_clamp_run(%input, %packed_weight_bias)
+        return (%r) )";
+
+  SubgraphRewriter rewriter;
+  rewriter.RegisterRewritePattern(
+      conv_2d_pattern, prepacked_ops_conv2d_pattern);
+  rewriter.runOnGraph(graph);
+}
+
+void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
+  SubgraphRewriter rewriter;
+
+  std::string conv2d_prepack_run_hardtanh_fused = R"(
+    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
+          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
+        %packed_weight_bias : __torch__.torch.classes.vulkan.Conv2dOpContext = vulkan_prepack::conv2d_clamp_prepack(
+            %weight, %bias, %stride, %padding, %dilation, %groups,
+            %output_min, %output_max)
+        %r = vulkan_prepack::conv2d_clamp_run(%input, %packed_weight_bias)
+        return (%r) )";
+
+  std::string conv2d_prepack_run_hardtanh = R"(
+    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
+          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
+        %packed_weight_bias = vulkan_prepack::conv2d_clamp_prepack(
+            %weight, %bias, %stride, %padding, %dilation, %groups,
+            %dummy_min_max, %dummy_min_max)
+        %conv2d_res = vulkan_prepack::conv2d_clamp_run(%input, %packed_weight_bias)
+        %r = aten::hardtanh(%conv2d_res, %output_min, %output_max)
+        return (%r) )";
+
+  rewriter.RegisterRewritePattern(
+      conv2d_prepack_run_hardtanh, conv2d_prepack_run_hardtanh_fused);
+
+  std::string conv2d_prepack_run_hardtanh_inplace = R"(
+    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
+          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
+        %packed_weight_bias = vulkan_prepack::conv2d_clamp_prepack(
+            %weight, %bias, %stride, %padding, %dilation, %groups,
+            %dummy_min_max, %dummy_min_max)
+        %conv2d_res = vulkan_prepack::conv2d_clamp_run(%input, %packed_weight_bias)
+        %r = aten::hardtanh_(%conv2d_res, %output_min, %output_max)
+        return (%r) )";
+
+  rewriter.RegisterRewritePattern(
+      conv2d_prepack_run_hardtanh_inplace, conv2d_prepack_run_hardtanh_fused);
+
+  rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
+}
+
+void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
+  SubgraphRewriter rewriter;
+
+  std::string conv2d_prepack_run_relu_fused = R"(
+    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
+          %dilation:int[], %groups:int, %dummy_min_max):
+        %output_min: float = prim::Constant[value=0.0]()
+        %output_max: None = prim::Constant()
+        %packed_weight_bias : __torch__.torch.classes.vulkan.Conv2dOpContext = vulkan_prepack::conv2d_clamp_prepack(
+            %weight, %bias, %stride, %padding, %dilation, %groups,
+            %output_min, %output_max)
+        %r = vulkan_prepack::conv2d_clamp_run(%input, %packed_weight_bias)
+        return (%r) )";
+
+  std::string conv2d_prepack_run_relu = R"(
+    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
+          %dilation:int[], %groups:int, %dummy_min_max):
+        %packed_weight_bias = vulkan_prepack::conv2d_clamp_prepack(
+            %weight, %bias, %stride, %padding, %dilation, %groups,
+            %dummy_min_max, %dummy_min_max)
+        %conv2d_res = vulkan_prepack::conv2d_clamp_run(%input, %packed_weight_bias)
+        %r = aten::relu(%conv2d_res)
+        return (%r) )";
+
+  rewriter.RegisterRewritePattern(
+      conv2d_prepack_run_relu, conv2d_prepack_run_relu_fused);
+
+  std::string conv2d_prepack_run_relu_inplace = R"(
+    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
+          %dilation:int[], %groups:int, %dummy_min_max):
+        %packed_weight_bias = vulkan_prepack::conv2d_clamp_prepack(
+            %weight, %bias, %stride, %padding, %dilation, %groups,
+            %dummy_min_max, %dummy_min_max)
+        %conv2d_res = vulkan_prepack::conv2d_clamp_run(%input, %packed_weight_bias)
+        %r = aten::relu_(%conv2d_res)
+        return (%r) )";
+
+  rewriter.RegisterRewritePattern(
+      conv2d_prepack_run_relu_inplace, conv2d_prepack_run_relu_fused);
+  rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
+}
+
+} // namespace
+
+void vulkanInsertPrePackedOps(std::shared_ptr<Graph>& graph) {
+  insertPrePackedConv2dOp(graph);
+}
+
+void vulkanInsertPrePackedOps(script::Module& module) {
+  for (auto& method : module.get_methods()) {
+    auto graph = method.graph();
+    vulkanInsertPrePackedOps(graph);
+  }
+  for (script::Module m : module.children()) {
+    vulkanInsertPrePackedOps(m);
+  }
+}
+
+void vulkanFusePrePackedConvWithClamp(script::Module& module) {
+  auto graph = module.get_method("forward").graph();
+  fuseReluWithPackedOps(graph);
+  fuseHardtanhWithPackedOps(graph);
+}
+
+void vulkanFoldPrePackingOps(script::Module& m) {
+  PrePackingOpsFilterFn filter_fn = [](const Node* n) -> bool {
+    return (
+        n->kind() ==
+        Symbol::fromQualString("vulkan_prepack::conv2d_clamp_prepack"));
+  };
+  PrePackingOpsFolder(m, filter_fn, "prepack_folding");
+}
+
+script::Module vulkanOptimizeForMobile(const script::Module& m) {
+  auto cloned_module = m.clone();
+  cloned_module.eval();
+  cloned_module = FoldConvBatchNorm2d(cloned_module);
+  vulkanInsertPrePackedOps(cloned_module);
+  cloned_module = freeze_module(cloned_module);
+  vulkanFusePrePackedConvWithClamp(cloned_module);
+  vulkanFoldPrePackingOps(cloned_module);
+  removeDropout(cloned_module);
+  return cloned_module;
+}
+
+#else
+
+void vulkanInsertPrePackedOps(std::shared_ptr<Graph>& graph) {
+  TORCH_INTERNAL_ASSERT(
+      false, "Vulkan is not enabled. Please build with USE_VULKAN=1");
+}
+
+void vulkanInsertPrePackedOps(script::Module& module) {
+  TORCH_INTERNAL_ASSERT(
+      false, "Vulkan is not enabled. Please build with USE_VULKAN=1");
+}
+
+void vulkanFusePrePackedConvWithClamp(script::Module& module) {
+  TORCH_INTERNAL_ASSERT(
+      false, "Vulkan is not enabled. Please build with USE_VULKAN=1");
+}
+
+void vulkanFoldPrePackingOps(script::Module& m) {
+  TORCH_INTERNAL_ASSERT(
+      false, "Vulkan is not enabled. Please build with USE_VULKAN=1");
+}
+
+script::Module vulkanOptimizeForMobile(const script::Module& module) {
+  TORCH_INTERNAL_ASSERT(
+      false,
+      "Mobile optimization is only available with Vulkan at the moment. "
+      "Vulkan is not enabled. Please build with USE_VULKAN=1");
+  return module;
+}
+
+#endif
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/passes/vulkan_rewrite.h b/torch/csrc/jit/passes/vulkan_rewrite.h
new file mode 100644
index 0000000..31b9f1b
--- /dev/null
+++ b/torch/csrc/jit/passes/vulkan_rewrite.h
@@ -0,0 +1,14 @@
+#pragma once
+
+#include <torch/csrc/jit/api/module.h>
+#include <torch/csrc/jit/ir/ir.h>
+
+namespace torch {
+namespace jit {
+TORCH_API void vulkanInsertPrePackedOps(std::shared_ptr<Graph>& graph);
+TORCH_API void vulkanInsertPrePackedOps(script::Module& module);
+TORCH_API void vulkanFusePrePackedConvWithClamp(script::Module& module);
+TORCH_API void vulkanFoldPrePackingOps(script::Module& module);
+TORCH_API script::Module vulkanOptimizeForMobile(const script::Module& module);
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/passes/xnnpack_rewrite.cpp b/torch/csrc/jit/passes/xnnpack_rewrite.cpp
index 8d8888e..153b4e2 100644
--- a/torch/csrc/jit/passes/xnnpack_rewrite.cpp
+++ b/torch/csrc/jit/passes/xnnpack_rewrite.cpp
@@ -93,41 +93,6 @@
   rewriter.runOnGraph(graph);
 }
 
-bool isClampFusable(
-    const Match& match,
-    const std::unordered_map<std::string, Value*>& vmap) {
-  const auto& match_vmap = match.values_map;
-  TORCH_CHECK(
-      vmap.find("dummy_min_max") != vmap.end(),
-      "Expected to find dummy_min_max Value in the subgraph to be replaced.");
-  auto dummy_min_max =
-      graph_rewrite_helper::getIValue("dummy_min_max", match_vmap, vmap);
-
-  auto is_fusable = !dummy_min_max || dummy_min_max.value().isNone();
-
-  // Also check if the output_min and output_max values are actually constant.
-  // If hardtanh's min/max Value's are not actually constants, we will end up
-  // rerouting those values to prepack op. And if they are not constants
-  // we will not be able to remove prepacking ops.
-  if (vmap.find("output_min") != vmap.end()) {
-    // aten::relu pattern does not have output_min/output_max.
-    // aten::hardtanh/_ does.
-    TORCH_CHECK(
-        vmap.find("output_max") != vmap.end(),
-        "Expected to find output_max as well given "
-        "output_min exist in pattern graph.");
-    // If output_min/max are not constant, we get c10::nullopt.
-    auto output_min =
-        graph_rewrite_helper::getIValue("output_min", match_vmap, vmap);
-    auto output_max =
-        graph_rewrite_helper::getIValue("output_max", match_vmap, vmap);
-    is_fusable =
-        is_fusable && (output_min.has_value() && output_max.has_value());
-  }
-
-  return is_fusable;
-}
-
 void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
   SubgraphRewriter rewriter;
 
@@ -194,7 +159,7 @@
   rewriter.RegisterRewritePattern(
       conv2d_prepack_run_hardtanh_inplace, conv2d_prepack_run_hardtanh_fused);
 
-  rewriter.runOnGraph(graph, isClampFusable);
+  rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
 }
 
 void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
@@ -266,7 +231,7 @@
       linear_prepack_run_relu_inplace, linear_prepack_run_relu_fused);
   rewriter.RegisterRewritePattern(
       conv2d_prepack_run_relu_inplace, conv2d_prepack_run_relu_fused);
-  rewriter.runOnGraph(graph, isClampFusable);
+  rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
 }
 
 void runCanonicalOptimizations(script::Module& module) {
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 14aace4..910d5cd 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -54,6 +54,7 @@
 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
 #include <torch/csrc/jit/passes/tensorexpr_fuser.h>
 #include <torch/csrc/jit/passes/utils/check_alias_annotation.h>
+#include <torch/csrc/jit/passes/vulkan_rewrite.h>
 #include <torch/csrc/jit/passes/xnnpack_rewrite.h>
 #include <torch/csrc/jit/python/pybind_utils.h>
 #include <torch/csrc/jit/python/python_arg_flatten.h>
@@ -542,6 +543,31 @@
             return optimizeForMobile(module, optimization_blacklist);
           })
       .def(
+          "_jit_pass_vulkan_insert_prepacked_ops",
+          [](std::shared_ptr<Graph>& graph) {
+            return vulkanInsertPrePackedOps(graph);
+          })
+      .def(
+          "_jit_pass_vulkan_insert_prepacked_ops",
+          [](script::Module& module) {
+            return vulkanInsertPrePackedOps(module);
+          })
+      .def(
+          "_jit_pass_vulkan_fuse_clamp_w_prepacked_conv",
+          [](script::Module& module) {
+            return vulkanFusePrePackedConvWithClamp(module);
+          })
+      .def(
+          "_jit_pass_vulkan_fold_prepacking_ops",
+          [](script::Module& module) {
+            return vulkanFoldPrePackingOps(module);
+          })
+      .def(
+          "_jit_pass_vulkan_optimize_for_mobile",
+          [](script::Module& module) {
+            return vulkanOptimizeForMobile(module);
+          })
+      .def(
           "_jit_pass_onnx_unpack_quantized_weights",
           [](std::shared_ptr<Graph>& graph,
              std::map<std::string, IValue>& paramsDict) {