[vulkan] jit passes for vulkan conv2d prepack and fuse with clamp (#39282)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/39282

Adds JIT rewrite passes for the Vulkan backend: aten::conv2d is rewritten into vulkan_prepack::conv2d_clamp_prepack / conv2d_clamp_run, a following relu/hardtanh (clamp) is fused into the prepacked conv, and the prepacking ops are folded away on frozen modules. The passes are exposed to Python, bundled into vulkanOptimizeForMobile, and reachable from the optimize_for_mobile binary via a new --vulkan flag. The Vulkan conv2d path additionally handles grouped/depthwise convolutions with prepacked weights, and test_vulkan.py covers the new rewrites.
Test Plan: Imported from OSS
Differential Revision: D21962424
Pulled By: IvanKobzarev
fbshipit-source-id: 2d20e827d2c3836b7e6b443293377c68dc1ffa5a
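
For context, a minimal sketch of how the new passes are driven from Python (mirroring test/test_vulkan.py added below; the toy module, shapes and names are illustrative, and a build with USE_VULKAN=1 is assumed):

    import io
    import torch
    from torch.nn import functional as F

    class ConvRelu(torch.nn.Module):
        def __init__(self):
            super(ConvRelu, self).__init__()
            self.weight = torch.nn.Parameter(torch.rand(8, 3, 3, 3), requires_grad=False)
            self.bias = torch.nn.Parameter(torch.rand(8), requires_grad=False)
            self.strides = (1, 1)
            self.paddings = (1, 1)
            self.dilations = (1, 1)
            self.groups = 1

        def forward(self, x):
            o = F.conv2d(x, self.weight, self.bias,
                         self.strides, self.paddings, self.dilations, self.groups)
            return F.relu(o)

    m = torch.jit.script(ConvRelu())
    m.eval()
    # Rewrite aten::conv2d into vulkan_prepack::conv2d_clamp_prepack / conv2d_clamp_run.
    torch._C._jit_pass_vulkan_insert_prepacked_ops(m._c)
    # Freeze so weight/bias/clamp bounds become constants, then fuse the relu into
    # the prepacked conv and fold the prepack op away.
    m._c = torch._C._freeze_module(m._c)
    torch._C._jit_pass_vulkan_fuse_clamp_w_prepacked_conv(m._c)
    torch._C._jit_pass_vulkan_fold_prepacking_ops(m._c)
    # Round-trip through serialization (as the test does) and inspect the graph.
    buffer = io.BytesIO()
    torch.jit.save(m, buffer)
    buffer.seek(0)
    print(torch.jit.load(buffer).graph)

vulkanOptimizeForMobile chains the same passes, adding FoldConvBatchNorm2d and removeDropout.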
diff --git a/aten/src/ATen/native/vulkan/VulkanAten.cpp b/aten/src/ATen/native/vulkan/VulkanAten.cpp
index bd71f82..a209e71 100644
--- a/aten/src/ATen/native/vulkan/VulkanAten.cpp
+++ b/aten/src/ATen/native/vulkan/VulkanAten.cpp
@@ -179,7 +179,7 @@
voutput,
vinput,
weight.data_ptr<float>(),
- bias.defined() ? c10::make_optional<float*>(bias.data_ptr<float>())
+ bias.defined() ? c10::make_optional<const float*>(bias.data_ptr<float>())
: c10::nullopt,
params);
return new_with_vtensor_vulkan(std::move(voutput), input.options());
@@ -242,7 +242,8 @@
voutput,
vinput,
vweight,
- hasBias ? c10::make_optional((*bias).data_ptr<float>()) : c10::nullopt,
+ hasBias ? c10::make_optional<const float*>((*bias).data_ptr<float>())
+ : c10::nullopt,
params,
output_min,
output_max);
diff --git a/aten/src/ATen/native/vulkan/VulkanConvolution.cpp b/aten/src/ATen/native/vulkan/VulkanConvolution.cpp
index a3128ff..ef8c18a 100644
--- a/aten/src/ATen/native/vulkan/VulkanConvolution.cpp
+++ b/aten/src/ATen/native/vulkan/VulkanConvolution.cpp
@@ -66,14 +66,13 @@
const auto stride_expanded = expand_param_if_needed(stride, "stride", 2);
const auto dilation_expanded =
expand_param_if_needed(dilation, "dilation", 2);
- const Tensor weight_nchw = weight.contiguous();
+ Tensor weight_nchw = weight.contiguous();
+ auto ws = weight_nchw.sizes();
return ContextConv2D{
- at::native::vulkan_convolution_prepack_weights(weight),
+ groups == 1 ? at::native::vulkan_convolution_prepack_weights(weight_nchw)
+ : weight_nchw.vulkan(),
bias.has_value() ? c10::make_optional((*bias).vulkan()) : c10::nullopt,
- {weight_nchw.sizes()[0],
- weight_nchw.sizes()[1],
- weight_nchw.sizes()[2],
- weight_nchw.sizes()[3]},
+ {{ws[0], ws[1], ws[2], ws[3]}},
{padding_expanded[0], padding_expanded[1]},
{stride_expanded[0], stride_expanded[1]},
{dilation_expanded[0], dilation_expanded[1]},
diff --git a/aten/src/ATen/native/vulkan/VulkanOps.cpp b/aten/src/ATen/native/vulkan/VulkanOps.cpp
index 2dda64c..c69a70b 100644
--- a/aten/src/ATen/native/vulkan/VulkanOps.cpp
+++ b/aten/src/ATen/native/vulkan/VulkanOps.cpp
@@ -176,7 +176,7 @@
}
VBuffer bufferFromOptionalHostData(
- c10::optional<float*> data,
+ c10::optional<const float*> data,
const uint32_t size) {
const auto sizeAligned =
ROUND_UP(size, context().limits().minStorageBufferOffsetAlignment);
@@ -202,17 +202,15 @@
void conv2d_depthwise(
VulkanTensor& output,
const VulkanTensor& input,
- const float* weight,
- const c10::optional<float*> bias,
- const Conv2DParams params,
+ const VulkanTensor& weight,
+ const VBuffer& biasBuffer,
+ const Conv2DParams& params,
c10::optional<float> output_min,
c10::optional<float> output_max) {
TORCH_INTERNAL_ASSERT(params.G == params.C);
auto osizes = output.sizes();
TORCH_INTERNAL_ASSERT(osizes[2] == params.OH);
TORCH_INTERNAL_ASSERT(osizes[3] == params.OW);
- auto biasBuffer =
- bufferFromOptionalHostData(bias, conv2d_biasBufferSize(params.OC));
struct ConstBlock {
int32_t padding[2];
int32_t kernelSize[2];
@@ -234,9 +232,6 @@
output_max ? *output_max : std::numeric_limits<float>::infinity()};
VBuffer constBuffer = makeUniformConstBuffer((void*)&cb, sizeof(cb));
- VulkanTensor kernel{{params.OC, params.KH, params.KW}};
- kernel.set_data_from_host(weight);
-
VkDescriptorSetLayout descriptorSetLayout{};
VkDescriptorPool descriptorPool{};
VkDescriptorSet descriptorSet{};
@@ -256,7 +251,7 @@
output.image()->bindStorageImage(descriptorSet, 0);
input.image()->bindShaderRead(descriptorSet, 1);
- kernel.image()->bindShaderRead(descriptorSet, 2);
+ weight.image()->bindShaderRead(descriptorSet, 2);
biasBuffer.bind(descriptorSet, 3);
constBuffer.bind(descriptorSet, 4);
@@ -269,7 +264,7 @@
auto commandBuffer = computeUnit.commandBuffer();
output.image()->addImageMemoryBarrierToGeneral(commandBuffer);
input.image()->addImageMemoryBarrierToShaderRead(commandBuffer);
- kernel.image()->addImageMemoryBarrierToShaderRead(commandBuffer);
+ weight.image()->addImageMemoryBarrierToShaderRead(commandBuffer);
computeUnit.dispatchCommandBuffer(
params.OW, params.OH, params.OC_4, workGroupSize);
computeUnit.endCommandBuffer();
@@ -279,6 +274,44 @@
vkDestroyDescriptorSetLayout(device, descriptorSetLayout, nullptr);
}
+void conv2d_depthwise(
+ VulkanTensor& output,
+ const VulkanTensor& input,
+ const VulkanTensor& weight,
+ const c10::optional<const float*> bias,
+ const Conv2DParams params,
+ c10::optional<float> output_min,
+ c10::optional<float> output_max) {
+ conv2d_depthwise(
+ output,
+ input,
+ weight,
+ bufferFromOptionalHostData(bias, conv2d_biasBufferSize(params.OC)),
+ params,
+ output_min,
+ output_max);
+}
+
+void conv2d_depthwise(
+ VulkanTensor& output,
+ const VulkanTensor& input,
+ const float* weight,
+ const c10::optional<const float*> bias,
+ const Conv2DParams params,
+ c10::optional<float> output_min,
+ c10::optional<float> output_max) {
+ VulkanTensor weightTensor{{params.OC, params.KH, params.KW}};
+ weightTensor.set_data_from_host(weight);
+ conv2d_depthwise(
+ output,
+ input,
+ weightTensor,
+ bufferFromOptionalHostData(bias, conv2d_biasBufferSize(params.OC)),
+ params,
+ output_min,
+ output_max);
+}
+
ImageSizes conv2d_prepack_weights_image_sizes(
int64_t OC,
int64_t C,
@@ -463,7 +496,7 @@
VulkanTensor& output,
const VulkanTensor& input,
const VImage& kernelImage,
- const c10::optional<float*> bias,
+ const c10::optional<const float*> bias,
const Conv2DParams& params,
c10::optional<float> output_min,
c10::optional<float> output_max) {
@@ -483,10 +516,22 @@
VulkanTensor& output,
const VulkanTensor& input,
const VulkanTensor& weight_prepacked,
- c10::optional<float*> bias,
+ c10::optional<const float*> bias,
const Conv2DParams params,
c10::optional<float> output_min,
c10::optional<float> output_max) {
+ if (params.G > 1) {
+ conv2d_depthwise(
+ output,
+ input,
+ weight_prepacked,
+ bufferFromOptionalHostData(bias, conv2d_biasBufferSize(params.OC)),
+ params,
+ output_min,
+ output_max);
+ return;
+ }
+
conv2d(
output,
input,
@@ -505,6 +550,18 @@
const Conv2DParams params,
c10::optional<float> output_min,
c10::optional<float> output_max) {
+ if (params.G > 1) {
+ conv2d_depthwise(
+ output,
+ input,
+ weight_prepacked,
+ *(bias.buffer()),
+ params,
+ output_min,
+ output_max);
+ return;
+ }
+
conv2d(
output,
input,
@@ -519,7 +576,7 @@
VulkanTensor& output,
const VulkanTensor& input,
const float* weight,
- const c10::optional<float*> bias,
+ const c10::optional<const float*> bias,
const Conv2DParams params,
c10::optional<float> output_min,
c10::optional<float> output_max) {
diff --git a/aten/src/ATen/native/vulkan/VulkanOps.h b/aten/src/ATen/native/vulkan/VulkanOps.h
index 969bd85..64b32ea 100644
--- a/aten/src/ATen/native/vulkan/VulkanOps.h
+++ b/aten/src/ATen/native/vulkan/VulkanOps.h
@@ -37,7 +37,7 @@
VulkanTensor& output,
const VulkanTensor& input,
const float* weight,
- const c10::optional<float*> bias,
+ const c10::optional<const float*> bias,
const Conv2DParams params,
c10::optional<float> output_min = c10::nullopt,
c10::optional<float> output_max = c10::nullopt);
@@ -46,7 +46,7 @@
VulkanTensor& output,
const VulkanTensor& input,
const VulkanTensor& weight_prepacked,
- const c10::optional<float*> bias,
+ const c10::optional<const float*> bias,
const Conv2DParams params,
c10::optional<float> output_min = c10::nullopt,
c10::optional<float> output_max = c10::nullopt);
diff --git a/aten/src/ATen/test/vulkan_test.cpp b/aten/src/ATen/test/vulkan_test.cpp
index da6ece7..9a6ed36 100644
--- a/aten/src/ATen/test/vulkan_test.cpp
+++ b/aten/src/ATen/test/vulkan_test.cpp
@@ -496,7 +496,7 @@
ASSERT_TRUE(no_prepack_check);
auto prepack = callOpByName(
- "vulkan::conv2d_clamp_prepack",
+ "vulkan_prepack::conv2d_clamp_prepack",
"",
t_w,
t_b,
@@ -507,7 +507,7 @@
output_min,
output_max);
auto tv_out_prepack_ivalues =
- callOpByName("vulkan::conv2d_clamp_run", "", tv_in, prepack[0]);
+ callOpByName("vulkan_prepack::conv2d_clamp_run", "", tv_in, prepack[0]);
auto tv_out_prepack = tv_out_prepack_ivalues[0].toTensor();
auto t_out_prepack = tv_out_prepack.cpu();
const auto prepack_check = almostEqual(t_out_prepack, t_out_expected);
diff --git a/binaries/CMakeLists.txt b/binaries/CMakeLists.txt
index b7a7f95..075bc05 100644
--- a/binaries/CMakeLists.txt
+++ b/binaries/CMakeLists.txt
@@ -103,3 +103,4 @@
caffe2_binary_target("tutorial_blob.cc")
caffe2_binary_target("dump_operator_names.cc")
+caffe2_binary_target("optimize_for_mobile.cc")
diff --git a/binaries/optimize_for_mobile.cc b/binaries/optimize_for_mobile.cc
index 4d91448..94293ba 100644
--- a/binaries/optimize_for_mobile.cc
+++ b/binaries/optimize_for_mobile.cc
@@ -17,6 +17,7 @@
#include <string>
#include "torch/csrc/jit/api/module.h"
+#include "torch/csrc/jit/passes/vulkan_rewrite.h"
#include "torch/csrc/jit/passes/xnnpack_rewrite.h"
#include "torch/csrc/jit/serialization/import.h"
@@ -29,6 +30,7 @@
save_for_mobile,
false,
"Save the model with bytecode format compatible with lite inteprter.");
+C10_DEFINE_bool(vulkan, false, "Apply Vulkan-specific optimize_for_mobile passes.");
int main(int argc, char** argv) {
c10::SetUsageMessage(
@@ -52,7 +54,10 @@
}
auto module = torch::jit::load(FLAGS_model);
- auto optimized_module = torch::jit::optimizeForMobile(module);
+
+ auto optimized_module = FLAGS_vulkan
+ ? torch::jit::vulkanOptimizeForMobile(module)
+ : torch::jit::optimizeForMobile(module);
if (FLAGS_save_for_mobile) {
optimized_module._save_for_mobile(output_model_name);
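
For reference, a hypothetical invocation of the new flag (model filename is illustrative; the other flags are defined in this file):

    optimize_for_mobile --model=mobilenet.pt --vulkan=true --save_for_mobile=true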
diff --git a/test/run_test.py b/test/run_test.py
index 36d4852..268c92b 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -48,6 +48,7 @@
'test_optim',
'test_mobile_optimizer',
'test_xnnpack_integration',
+ 'test_vulkan',
'test_quantization',
'test_sparse',
'test_serialization',
diff --git a/test/test_vulkan.py b/test/test_vulkan.py
new file mode 100644
index 0000000..602feca
--- /dev/null
+++ b/test/test_vulkan.py
@@ -0,0 +1,162 @@
+import unittest
+import torch
+from torch.nn import functional as F
+
+from torch.testing._internal.common_utils import TestCase, run_tests
+from torch.testing import FileCheck
+import io
+
+@unittest.skipUnless(torch.is_vulkan_available(),
+ "Vulkan backend must be available for these tests.")
+class TestVulkanRewritePass(TestCase):
+ @staticmethod
+ def validate_transformed_module(
+            # The module instance under test; named 'self' only to please flake8.
+ self,
+ pattern_count_map,
+ data_shape,
+ prepack_removal=False,
+ fuse_clamping_ops=False):
+ module_instance = self
+ scripted_model = torch.jit.script(module_instance)
+ scripted_model.eval()
+ input_data = torch.normal(1, 20, size=data_shape)
+ ref_result = scripted_model(input_data)
+ torch._C._jit_pass_vulkan_insert_prepacked_ops(scripted_model._c)
+ if fuse_clamping_ops or prepack_removal:
+ scripted_model._c = torch._C._freeze_module(scripted_model._c)
+ if fuse_clamping_ops:
+ torch._C._jit_pass_vulkan_fuse_clamp_w_prepacked_conv(scripted_model._c)
+ if prepack_removal:
+ torch._C._jit_pass_vulkan_fold_prepacking_ops(scripted_model._c)
+
+ buffer = io.BytesIO()
+ torch.jit.save(scripted_model, buffer)
+ buffer.seek(0)
+ deserialized_scripted_model = torch.jit.load(buffer)
+ for pattern, v in pattern_count_map.items():
+ if (v == 0):
+ FileCheck().check(pattern).run(deserialized_scripted_model.graph)
+ elif (v == -1):
+ FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
+ else:
+ FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)
+
+ def test_conv(self):
+ # Conv params
+ batch_size = 2
+ input_channels_per_group = 6
+ height = 16
+ width = 16
+ output_channels_per_group = 6
+ groups = 4
+ kernel_h = kernel_w = 3
+ stride_h = stride_w = 1
+ pad_h = pad_w = 1
+ dilation = 1
+ input_channels = input_channels_per_group * groups
+ output_channels = output_channels_per_group * groups
+ kernels = (kernel_h, kernel_w)
+ strides = (stride_h, stride_w)
+ paddings = (pad_h, pad_w)
+ dilations = (dilation, dilation)
+ conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
+        conv_bias_shape = (output_channels,)
+
+ class Conv2D(torch.nn.Module):
+ def __init__(self):
+ super(Conv2D, self).__init__()
+ self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)), requires_grad=False)
+ self.bias = torch.nn.Parameter(torch.Tensor(torch.rand(conv_bias_shape)), requires_grad=False)
+ self.strides = strides
+ self.paddings = paddings
+ self.dilations = dilations
+ self.groups = groups
+
+ def forward(self, x):
+ return F.conv2d(x, self.weight, self.bias,
+ self.strides, self.paddings, self.dilations, self.groups)
+
+ data_shape = (batch_size, input_channels, height, width)
+ pattern_count_map = {"Tensor = aten::conv2d": -1,
+ "vulkan_prepack::conv2d_clamp_prepack": 1,
+ "vulkan_prepack::conv2d_clamp_run": 1}
+ TestVulkanRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape)
+
+ class Conv2DRelu(torch.nn.Module):
+ def __init__(self):
+ super(Conv2DRelu, self).__init__()
+ self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)), requires_grad=False)
+ self.bias = torch.nn.Parameter(torch.Tensor(torch.rand(conv_bias_shape)), requires_grad=False)
+ self.strides = strides
+ self.paddings = paddings
+ self.dilations = dilations
+ self.groups = groups
+
+ def forward(self, x):
+ o = F.conv2d(x, self.weight, self.bias,
+ self.strides, self.paddings, self.dilations, self.groups)
+ o = F.relu(o)
+ return o
+
+ data_shape = (batch_size, input_channels, height, width)
+ pattern_count_map = {"Tensor = aten::conv2d": -1,
+ "vulkan_prepack::conv2d_clamp_prepack": 1,
+ "vulkan_prepack::conv2d_clamp_run": 1}
+ TestVulkanRewritePass.validate_transformed_module(
+ Conv2DRelu(), pattern_count_map, data_shape)
+
+ pattern_count_map["aten::relu"] = 1
+ pattern_count_map["vulkan_prepack::conv2d_clamp_prepack"] = -1
+ TestVulkanRewritePass.validate_transformed_module(
+ Conv2DRelu(),
+ pattern_count_map,
+ data_shape,
+ prepack_removal=True)
+ pattern_count_map["aten::relu"] = -1
+ TestVulkanRewritePass.validate_transformed_module(
+ Conv2DRelu(),
+ pattern_count_map,
+ data_shape,
+ prepack_removal=True,
+ fuse_clamping_ops=True)
+
+
+ class Conv2DHardtanh(torch.nn.Module):
+ def __init__(self):
+ super(Conv2DHardtanh, self).__init__()
+ self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)), requires_grad=False)
+ self.bias = torch.nn.Parameter(torch.Tensor(torch.rand(conv_bias_shape)), requires_grad=False)
+ self.strides = strides
+ self.paddings = paddings
+ self.dilations = dilations
+ self.groups = groups
+
+ def forward(self, x):
+ o = F.conv2d(x, self.weight, self.bias,
+ self.strides, self.paddings, self.dilations, self.groups)
+ o = F.hardtanh(o)
+ return o
+
+ data_shape = (batch_size, input_channels, height, width)
+ pattern_count_map = {"Tensor = aten::conv2d": -1,
+ "vulkan_prepack::conv2d_clamp_prepack": 1,
+ "vulkan_prepack::conv2d_clamp_run": 1}
+ TestVulkanRewritePass.validate_transformed_module(Conv2DHardtanh(), pattern_count_map, data_shape)
+ pattern_count_map["aten::hardtanh"] = 1
+ pattern_count_map["vulkan_prepack::conv2d_clamp_prepack"] = -1
+ TestVulkanRewritePass.validate_transformed_module(
+ Conv2DHardtanh(),
+ pattern_count_map,
+ data_shape,
+ prepack_removal=True)
+ pattern_count_map["aten::hardtanh"] = -1
+ TestVulkanRewritePass.validate_transformed_module(
+            Conv2DHardtanh(),
+ pattern_count_map,
+ data_shape,
+ prepack_removal=True,
+ fuse_clamping_ops=True)
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index 31e44de..3ab3116 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -177,6 +177,7 @@
"torch/csrc/jit/passes/utils/memory_dag.cpp",
"torch/csrc/jit/passes/utils/subgraph_utils.cpp",
"torch/csrc/jit/passes/xnnpack_rewrite.cpp",
+ "torch/csrc/jit/passes/vulkan_rewrite.cpp",
"torch/csrc/jit/passes/quantization/helper.cpp",
"torch/csrc/jit/passes/quantization/quantization_type.cpp",
"torch/csrc/jit/passes/quantization/insert_observers.cpp",
diff --git a/torch/csrc/jit/passes/graph_rewrite_helper.cpp b/torch/csrc/jit/passes/graph_rewrite_helper.cpp
index a52b09e..4c53d57 100644
--- a/torch/csrc/jit/passes/graph_rewrite_helper.cpp
+++ b/torch/csrc/jit/passes/graph_rewrite_helper.cpp
@@ -152,6 +152,41 @@
rewriter_conv3d.runOnGraph(graph, filter_conv3d);
}
+bool isClampFusable(
+ const Match& match,
+ const std::unordered_map<std::string, Value*>& vmap) {
+ const auto& match_vmap = match.values_map;
+ TORCH_CHECK(
+ vmap.find("dummy_min_max") != vmap.end(),
+ "Expected to find dummy_min_max Value in the subgraph to be replaced.");
+ auto dummy_min_max =
+ graph_rewrite_helper::getIValue("dummy_min_max", match_vmap, vmap);
+
+ auto is_fusable = !dummy_min_max || dummy_min_max.value().isNone();
+
+ // Also check if the output_min and output_max values are actually constant.
+ // If hardtanh's min/max Value's are not actually constants, we will end up
+ // rerouting those values to prepack op. And if they are not constants
+ // we will not be able to remove prepacking ops.
+ if (vmap.find("output_min") != vmap.end()) {
+ // aten::relu pattern does not have output_min/output_max.
+ // aten::hardtanh/_ does.
+ TORCH_CHECK(
+ vmap.find("output_max") != vmap.end(),
+ "Expected to find output_max as well given "
+ "output_min exist in pattern graph.");
+ // If output_min/max are not constant, we get c10::nullopt.
+ auto output_min =
+ graph_rewrite_helper::getIValue("output_min", match_vmap, vmap);
+ auto output_max =
+ graph_rewrite_helper::getIValue("output_max", match_vmap, vmap);
+ is_fusable =
+ is_fusable && (output_min.has_value() && output_max.has_value());
+ }
+
+ return is_fusable;
+}
+
} // namespace graph_rewrite_helper
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/passes/graph_rewrite_helper.h b/torch/csrc/jit/passes/graph_rewrite_helper.h
index 5d64228..00ce273 100644
--- a/torch/csrc/jit/passes/graph_rewrite_helper.h
+++ b/torch/csrc/jit/passes/graph_rewrite_helper.h
@@ -19,6 +19,10 @@
const std::unordered_map<std::string, Value*>& vmap);
void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph);
+bool isClampFusable(
+ const Match& match,
+ const std::unordered_map<std::string, Value*>& vmap);
+
using MatchFilter = std::function<
bool(const Match&, const std::unordered_map<std::string, Value*>&)>;
diff --git a/torch/csrc/jit/passes/vulkan_rewrite.cpp b/torch/csrc/jit/passes/vulkan_rewrite.cpp
new file mode 100644
index 0000000..6d0dc6c
--- /dev/null
+++ b/torch/csrc/jit/passes/vulkan_rewrite.cpp
@@ -0,0 +1,205 @@
+#include <ATen/core/jit_type.h>
+#ifdef USE_VULKAN
+#include <ATen/native/vulkan/VulkanOpContext.h>
+#endif
+
+#include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/ir/subgraph_matcher.h>
+#include <torch/csrc/jit/passes/constant_pooling.h>
+#include <torch/csrc/jit/passes/fold_conv_bn.h>
+#include <torch/csrc/jit/passes/freeze_module.h>
+#include <torch/csrc/jit/passes/fuse_linear.h>
+#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
+#include <torch/csrc/jit/passes/prepack_folding.h>
+#include <torch/csrc/jit/passes/remove_dropout.h>
+#include <torch/csrc/jit/passes/subgraph_rewrite.h>
+#include <torch/csrc/jit/passes/vulkan_rewrite.h>
+
+namespace torch {
+namespace jit {
+
+#ifdef USE_VULKAN
+
+namespace {
+
+void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
+ graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);
+
+ std::string conv_2d_pattern = R"(
+ graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
+ %r = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
+ return (%r) )";
+
+ std::string prepacked_ops_conv2d_pattern = R"(
+ graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
+ %output_min_max : None = prim::Constant()
+ %packed_weight_bias = vulkan_prepack::conv2d_clamp_prepack(
+ %weight, %bias, %stride, %padding, %dilation, %groups,
+ %output_min_max, %output_min_max)
+ %r = vulkan_prepack::conv2d_clamp_run(%input, %packed_weight_bias)
+ return (%r) )";
+
+ SubgraphRewriter rewriter;
+ rewriter.RegisterRewritePattern(
+ conv_2d_pattern, prepacked_ops_conv2d_pattern);
+ rewriter.runOnGraph(graph);
+}
+
+void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
+ SubgraphRewriter rewriter;
+
+ std::string conv2d_prepack_run_hardtanh_fused = R"(
+ graph(%input, %weight, %bias, %stride:int[], %padding:int[],
+ %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
+ %packed_weight_bias : __torch__.torch.classes.vulkan.Conv2dOpContext = vulkan_prepack::conv2d_clamp_prepack(
+ %weight, %bias, %stride, %padding, %dilation, %groups,
+ %output_min, %output_max)
+ %r = vulkan_prepack::conv2d_clamp_run(%input, %packed_weight_bias)
+ return (%r) )";
+
+ std::string conv2d_prepack_run_hardtanh = R"(
+ graph(%input, %weight, %bias, %stride:int[], %padding:int[],
+ %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
+ %packed_weight_bias = vulkan_prepack::conv2d_clamp_prepack(
+ %weight, %bias, %stride, %padding, %dilation, %groups,
+ %dummy_min_max, %dummy_min_max)
+ %conv2d_res = vulkan_prepack::conv2d_clamp_run(%input, %packed_weight_bias)
+ %r = aten::hardtanh(%conv2d_res, %output_min, %output_max)
+ return (%r) )";
+
+ rewriter.RegisterRewritePattern(
+ conv2d_prepack_run_hardtanh, conv2d_prepack_run_hardtanh_fused);
+
+ std::string conv2d_prepack_run_hardtanh_inplace = R"(
+ graph(%input, %weight, %bias, %stride:int[], %padding:int[],
+ %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
+ %packed_weight_bias = vulkan_prepack::conv2d_clamp_prepack(
+ %weight, %bias, %stride, %padding, %dilation, %groups,
+ %dummy_min_max, %dummy_min_max)
+ %conv2d_res = vulkan_prepack::conv2d_clamp_run(%input, %packed_weight_bias)
+ %r = aten::hardtanh_(%conv2d_res, %output_min, %output_max)
+ return (%r) )";
+
+ rewriter.RegisterRewritePattern(
+ conv2d_prepack_run_hardtanh_inplace, conv2d_prepack_run_hardtanh_fused);
+
+ rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
+}
+
+void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
+ SubgraphRewriter rewriter;
+
+ std::string conv2d_prepack_run_relu_fused = R"(
+ graph(%input, %weight, %bias, %stride:int[], %padding:int[],
+ %dilation:int[], %groups:int, %dummy_min_max):
+ %output_min: float = prim::Constant[value=0.0]()
+ %output_max: None = prim::Constant()
+ %packed_weight_bias : __torch__.torch.classes.vulkan.Conv2dOpContext = vulkan_prepack::conv2d_clamp_prepack(
+ %weight, %bias, %stride, %padding, %dilation, %groups,
+ %output_min, %output_max)
+ %r = vulkan_prepack::conv2d_clamp_run(%input, %packed_weight_bias)
+ return (%r) )";
+
+ std::string conv2d_prepack_run_relu = R"(
+ graph(%input, %weight, %bias, %stride:int[], %padding:int[],
+ %dilation:int[], %groups:int, %dummy_min_max):
+ %packed_weight_bias = vulkan_prepack::conv2d_clamp_prepack(
+ %weight, %bias, %stride, %padding, %dilation, %groups,
+ %dummy_min_max, %dummy_min_max)
+ %conv2d_res = vulkan_prepack::conv2d_clamp_run(%input, %packed_weight_bias)
+ %r = aten::relu(%conv2d_res)
+ return (%r) )";
+
+ rewriter.RegisterRewritePattern(
+ conv2d_prepack_run_relu, conv2d_prepack_run_relu_fused);
+
+ std::string conv2d_prepack_run_relu_inplace = R"(
+ graph(%input, %weight, %bias, %stride:int[], %padding:int[],
+ %dilation:int[], %groups:int, %dummy_min_max):
+ %packed_weight_bias = vulkan_prepack::conv2d_clamp_prepack(
+ %weight, %bias, %stride, %padding, %dilation, %groups,
+ %dummy_min_max, %dummy_min_max)
+ %conv2d_res = vulkan_prepack::conv2d_clamp_run(%input, %packed_weight_bias)
+ %r = aten::relu_(%conv2d_res)
+ return (%r) )";
+
+ rewriter.RegisterRewritePattern(
+ conv2d_prepack_run_relu_inplace, conv2d_prepack_run_relu_fused);
+ rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
+}
+
+} // namespace
+
+void vulkanInsertPrePackedOps(std::shared_ptr<Graph>& graph) {
+ insertPrePackedConv2dOp(graph);
+}
+
+void vulkanInsertPrePackedOps(script::Module& module) {
+ for (auto& method : module.get_methods()) {
+ auto graph = method.graph();
+ vulkanInsertPrePackedOps(graph);
+ }
+ for (script::Module m : module.children()) {
+ vulkanInsertPrePackedOps(m);
+ }
+}
+
+void vulkanFusePrePackedConvWithClamp(script::Module& module) {
+ auto graph = module.get_method("forward").graph();
+ fuseReluWithPackedOps(graph);
+ fuseHardtanhWithPackedOps(graph);
+}
+
+void vulkanFoldPrePackingOps(script::Module& m) {
+ PrePackingOpsFilterFn filter_fn = [](const Node* n) -> bool {
+ return (
+ n->kind() ==
+ Symbol::fromQualString("vulkan_prepack::conv2d_clamp_prepack"));
+ };
+ PrePackingOpsFolder(m, filter_fn, "prepack_folding");
+}
+
+script::Module vulkanOptimizeForMobile(const script::Module& m) {
+ auto cloned_module = m.clone();
+ cloned_module.eval();
+ cloned_module = FoldConvBatchNorm2d(cloned_module);
+ vulkanInsertPrePackedOps(cloned_module);
+ cloned_module = freeze_module(cloned_module);
+ vulkanFusePrePackedConvWithClamp(cloned_module);
+ vulkanFoldPrePackingOps(cloned_module);
+ removeDropout(cloned_module);
+ return cloned_module;
+}
+
+#else
+
+void vulkanInsertPrePackedOps(std::shared_ptr<Graph>& graph) {
+  TORCH_INTERNAL_ASSERT(
+      false, "Vulkan is not enabled. Please build with USE_VULKAN=1");
+}
+
+void vulkanInsertPrePackedOps(script::Module& module) {
+  TORCH_INTERNAL_ASSERT(
+      false, "Vulkan is not enabled. Please build with USE_VULKAN=1");
+}
+
+void vulkanFusePrePackedConvWithClamp(script::Module& module) {
+  TORCH_INTERNAL_ASSERT(
+      false, "Vulkan is not enabled. Please build with USE_VULKAN=1");
+}
+
+void vulkanFoldPrePackingOps(script::Module& m) {
+  TORCH_INTERNAL_ASSERT(
+      false, "Vulkan is not enabled. Please build with USE_VULKAN=1");
+}
+
+script::Module vulkanOptimizeForMobile(const script::Module& module) {
+  TORCH_INTERNAL_ASSERT(
+      false, "Mobile optimization is only available with Vulkan at the moment. "
+      "Vulkan is not enabled. Please build with USE_VULKAN=1");
+ return module;
+}
+
+#endif
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/passes/vulkan_rewrite.h b/torch/csrc/jit/passes/vulkan_rewrite.h
new file mode 100644
index 0000000..31b9f1b
--- /dev/null
+++ b/torch/csrc/jit/passes/vulkan_rewrite.h
@@ -0,0 +1,14 @@
+#pragma once
+
+#include <torch/csrc/jit/api/module.h>
+#include <torch/csrc/jit/ir/ir.h>
+
+namespace torch {
+namespace jit {
+TORCH_API void vulkanInsertPrePackedOps(std::shared_ptr<Graph>& graph);
+TORCH_API void vulkanInsertPrePackedOps(script::Module& module);
+TORCH_API void vulkanFusePrePackedConvWithClamp(script::Module& module);
+TORCH_API void vulkanFoldPrePackingOps(script::Module& module);
+TORCH_API script::Module vulkanOptimizeForMobile(const script::Module& module);
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/passes/xnnpack_rewrite.cpp b/torch/csrc/jit/passes/xnnpack_rewrite.cpp
index 8d8888e..153b4e2 100644
--- a/torch/csrc/jit/passes/xnnpack_rewrite.cpp
+++ b/torch/csrc/jit/passes/xnnpack_rewrite.cpp
@@ -93,41 +93,6 @@
rewriter.runOnGraph(graph);
}
-bool isClampFusable(
- const Match& match,
- const std::unordered_map<std::string, Value*>& vmap) {
- const auto& match_vmap = match.values_map;
- TORCH_CHECK(
- vmap.find("dummy_min_max") != vmap.end(),
- "Expected to find dummy_min_max Value in the subgraph to be replaced.");
- auto dummy_min_max =
- graph_rewrite_helper::getIValue("dummy_min_max", match_vmap, vmap);
-
- auto is_fusable = !dummy_min_max || dummy_min_max.value().isNone();
-
- // Also check if the output_min and output_max values are actually constant.
- // If hardtanh's min/max Value's are not actually constants, we will end up
- // rerouting those values to prepack op. And if they are not constants
- // we will not be able to remove prepacking ops.
- if (vmap.find("output_min") != vmap.end()) {
- // aten::relu pattern does not have output_min/output_max.
- // aten::hardtanh/_ does.
- TORCH_CHECK(
- vmap.find("output_max") != vmap.end(),
- "Expected to find output_max as well given "
- "output_min exist in pattern graph.");
- // If output_min/max are not constant, we get c10::nullopt.
- auto output_min =
- graph_rewrite_helper::getIValue("output_min", match_vmap, vmap);
- auto output_max =
- graph_rewrite_helper::getIValue("output_max", match_vmap, vmap);
- is_fusable =
- is_fusable && (output_min.has_value() && output_max.has_value());
- }
-
- return is_fusable;
-}
-
void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
SubgraphRewriter rewriter;
@@ -194,7 +159,7 @@
rewriter.RegisterRewritePattern(
conv2d_prepack_run_hardtanh_inplace, conv2d_prepack_run_hardtanh_fused);
- rewriter.runOnGraph(graph, isClampFusable);
+ rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
}
void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
@@ -266,7 +231,7 @@
linear_prepack_run_relu_inplace, linear_prepack_run_relu_fused);
rewriter.RegisterRewritePattern(
conv2d_prepack_run_relu_inplace, conv2d_prepack_run_relu_fused);
- rewriter.runOnGraph(graph, isClampFusable);
+ rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
}
void runCanonicalOptimizations(script::Module& module) {
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 14aace4..910d5cd 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -54,6 +54,7 @@
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/passes/utils/check_alias_annotation.h>
+#include <torch/csrc/jit/passes/vulkan_rewrite.h>
#include <torch/csrc/jit/passes/xnnpack_rewrite.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/jit/python/python_arg_flatten.h>
@@ -542,6 +543,31 @@
return optimizeForMobile(module, optimization_blacklist);
})
.def(
+ "_jit_pass_vulkan_insert_prepacked_ops",
+ [](std::shared_ptr<Graph>& graph) {
+ return vulkanInsertPrePackedOps(graph);
+ })
+ .def(
+ "_jit_pass_vulkan_insert_prepacked_ops",
+ [](script::Module& module) {
+ return vulkanInsertPrePackedOps(module);
+ })
+ .def(
+ "_jit_pass_vulkan_fuse_clamp_w_prepacked_conv",
+ [](script::Module& module) {
+ return vulkanFusePrePackedConvWithClamp(module);
+ })
+ .def(
+ "_jit_pass_vulkan_fold_prepacking_ops",
+ [](script::Module& module) {
+ return vulkanFoldPrePackingOps(module);
+ })
+ .def(
+ "_jit_pass_vulkan_optimize_for_mobile",
+ [](script::Module& module) {
+ return vulkanOptimizeForMobile(module);
+ })
+ .def(
"_jit_pass_onnx_unpack_quantized_weights",
[](std::shared_ptr<Graph>& graph,
std::map<std::string, IValue>& paramsDict) {