fix asan failure for module freezing in conv bn folding (#42739)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42739
This is a test case which fails with ASAN on at the module freezing
step.
Test Plan:
```
USE_ASAN=1 USE_CUDA=0 python setup.py develop
LD_PRELOAD=/usr/lib64/libasan.so.4 python test/test_mobile_optimizer.py TestOptimizer.test_optimize_for_mobile_asan
// output tail: https://gist.github.com/vkuzo/7a0018b9e10ffe64dab0ac7381479f23
```
Imported from OSS
Reviewed By: kimishpatel
Differential Revision: D23005962
fbshipit-source-id: b7d4492e989af7c2e22197c16150812bd2dda7cc
diff --git a/test/test_mobile_optimizer.py b/test/test_mobile_optimizer.py
index 20a73b2..1e2387e 100644
--- a/test/test_mobile_optimizer.py
+++ b/test/test_mobile_optimizer.py
@@ -1,11 +1,13 @@
import unittest
import torch
+import torch.nn as nn
import torch.backends.xnnpack
import torch.utils.bundled_inputs
from torch.testing._internal.jit_utils import get_forward, get_forward_graph
from torch.utils.mobile_optimizer import *
from torch.nn import functional as F
from torch._C import MobileOptimizerType
+from torch.testing._internal.common_quantized import override_quantized_engine
FileCheck = torch._C.FileCheck
@@ -155,6 +157,52 @@
preserveThis = getattr(opt_m, "preserveThis", None)
self.assertNotEqual(preserveThis, None)
+ @unittest.skipUnless(torch.backends.xnnpack.enabled,
+ " XNNPACK must be enabled for these tests."
+ " Please build with USE_XNNPACK=1.")
+ def test_quantized_conv_no_asan_failures(self):
+ # There were ASAN failures when fold_conv_bn was run on
+ # already quantized conv modules. Verifying that this does
+ # not happen again.
+
+ if 'qnnpack' not in torch.backends.quantized.supported_engines:
+ return
+
+ class Child(nn.Module):
+ def __init__(self):
+ super(Child, self).__init__()
+ self.conv2 = nn.Conv2d(1, 1, 1)
+
+ def forward(self, x):
+ x = self.conv2(x)
+ return x
+
+ class Parent(nn.Module):
+ def __init__(self):
+ super(Parent, self).__init__()
+ self.quant = torch.quantization.QuantStub()
+ self.conv1 = nn.Conv2d(1, 1, 1)
+ self.child = Child()
+ self.dequant = torch.quantization.DeQuantStub()
+
+ def forward(self, x):
+ x = self.quant(x)
+ x = self.conv1(x)
+ x = self.child(x)
+ x = self.dequant(x)
+ return x
+
+ with override_quantized_engine('qnnpack'):
+ model = Parent()
+ model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
+ torch.quantization.prepare(model, inplace=True)
+ model(torch.randn(4, 1, 4, 4))
+ torch.quantization.convert(model, inplace=True)
+ model = torch.jit.script(model)
+ # this line should not have ASAN failures
+ model_optim = optimize_for_mobile(model)
+ self.assertFalse(hasattr(model_optim.conv1, "bias"))
+ self.assertFalse(hasattr(model_optim.child.conv2, "bias"))
def test_generate_mobile_module_lints(self):
class MyTestModule(torch.nn.Module):
diff --git a/torch/csrc/jit/passes/fold_conv_bn.cpp b/torch/csrc/jit/passes/fold_conv_bn.cpp
index 211fd90..c6ffe39 100644
--- a/torch/csrc/jit/passes/fold_conv_bn.cpp
+++ b/torch/csrc/jit/passes/fold_conv_bn.cpp
@@ -59,13 +59,15 @@
void addBiasForConvIfNone(Module& module, const std::string& pattern_name) {
auto t = module.type()->expect<ClassType>();
- auto real_typename = t->name()->qualifiedName();
- if (real_typename.size() >= pattern_name.size() &&
- (0 ==
- real_typename.compare(
- real_typename.size() - pattern_name.size(),
- pattern_name.size(),
- pattern_name))) {
+
+ const std::string real_typename = t->name()->qualifiedName();
+ const std::string demangled_typename = removeTorchMangle(real_typename);
+ bool is_floating_point_conv =
+ ((demangled_typename == "__torch__.torch.nn.modules.conv.Conv1d") ||
+ (demangled_typename == "__torch__.torch.nn.modules.conv.Conv2d") ||
+ (demangled_typename == "__torch__.torch.nn.modules.conv.Conv3d"));
+
+ if (is_floating_point_conv) {
if (!t->hasAttribute("bias")) {
auto optional_tensor_type = OptionalType::create(TensorType::get());
t->addAttribute("bias", optional_tensor_type, true);
diff --git a/torch/csrc/jit/passes/quantization/helper.cpp b/torch/csrc/jit/passes/quantization/helper.cpp
index dcaf92a..0580f41 100644
--- a/torch/csrc/jit/passes/quantization/helper.cpp
+++ b/torch/csrc/jit/passes/quantization/helper.cpp
@@ -609,13 +609,16 @@
return v->type()->cast<FunctionType>() && getFuncName(v) == functional;
}
+std::string removeTorchMangle(const std::string& orig_name) {
+ static std::regex mangle_re("\\.___torch_mangle_\\d+");
+ auto qualified_name = std::regex_replace(orig_name, mangle_re, "");
+ return qualified_name;
+}
+
c10::optional<std::string> getModuleName(Value* value) {
auto type = value->type()->cast<ClassType>();
if (type && type->name()) {
- static std::regex mangle_re("\\.___torch_mangle_\\d+");
- auto qualified_name =
- std::regex_replace(type->name()->qualifiedName(), mangle_re, "");
- return qualified_name;
+ return removeTorchMangle(type->name()->qualifiedName());
}
return c10::nullopt;
}
diff --git a/torch/csrc/jit/passes/quantization/helper.h b/torch/csrc/jit/passes/quantization/helper.h
index ee2d3b5..7100d4c 100644
--- a/torch/csrc/jit/passes/quantization/helper.h
+++ b/torch/csrc/jit/passes/quantization/helper.h
@@ -44,6 +44,12 @@
// Check if value is the input of the graph
TORCH_API bool hitGraphInput(Value* value);
+// Converts a mangled name, such as
+// __torch__.torch.nn.quantized.modules.conv.___torch_mangle_7.Conv2d
+// into an unmangled name, such as
+// __torch__.torch.nn.quantized.modules.conv.Conv2d
+TORCH_API std::string removeTorchMangle(const std::string& orig_name);
+
// Return the module name that corresponds to the value.
TORCH_API c10::optional<std::string> getModuleName(Value* value);