fix asan failure for module freezing in conv bn folding (#42739)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42739

This adds a test case which fails with ASAN at the module freezing step
when fold_conv_bn is run on an already quantized conv module, along with a
fix: addBiasForConvIfNone now only matches the floating point Conv1d/2d/3d
module types (exact match on the demangled type name), so quantized conv
modules are no longer modified.
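
For context, here is a minimal Python sketch of the name handling that the C++ change implements (illustrative only, not part of this PR; the pattern name "Conv2d" and the mangled type name are examples):

```python
import re

def remove_torch_mangle(qualified_name):
    # Mirrors removeTorchMangle in torch/csrc/jit/passes/quantization/helper.cpp:
    # strip the ".___torch_mangle_N" component from a qualified type name.
    return re.sub(r"\.___torch_mangle_\d+", "", qualified_name)

quantized_conv = "__torch__.torch.nn.quantized.modules.conv.___torch_mangle_7.Conv2d"
float_convs = {
    "__torch__.torch.nn.modules.conv.Conv1d",
    "__torch__.torch.nn.modules.conv.Conv2d",
    "__torch__.torch.nn.modules.conv.Conv3d",
}

# Old check: a suffix comparison against the pattern name (e.g. "Conv2d")
# also matches quantized conv modules, so they got a "bias" attribute added.
print(quantized_conv.endswith("Conv2d"))                   # True

# New check: demangle, then require an exact match against the floating
# point conv types, so quantized conv modules are skipped.
print(remove_torch_mangle(quantized_conv) in float_convs)  # False
```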

Test Plan:
```
USE_ASAN=1 USE_CUDA=0 python setup.py develop
LD_PRELOAD=/usr/lib64/libasan.so.4 python test/test_mobile_optimizer.py TestOptimizer.test_quantized_conv_no_asan_failures

// output tail: https://gist.github.com/vkuzo/7a0018b9e10ffe64dab0ac7381479f23
```

Imported from OSS

Reviewed By: kimishpatel

Differential Revision: D23005962

fbshipit-source-id: b7d4492e989af7c2e22197c16150812bd2dda7cc
diff --git a/test/test_mobile_optimizer.py b/test/test_mobile_optimizer.py
index 20a73b2..1e2387e 100644
--- a/test/test_mobile_optimizer.py
+++ b/test/test_mobile_optimizer.py
@@ -1,11 +1,13 @@
 import unittest
 import torch
+import torch.nn as nn
 import torch.backends.xnnpack
 import torch.utils.bundled_inputs
 from torch.testing._internal.jit_utils import get_forward, get_forward_graph
 from torch.utils.mobile_optimizer import *
 from torch.nn import functional as F
 from torch._C import MobileOptimizerType
+from torch.testing._internal.common_quantized import override_quantized_engine
 
 FileCheck = torch._C.FileCheck
 
@@ -155,6 +157,52 @@
         preserveThis = getattr(opt_m, "preserveThis", None)
         self.assertNotEqual(preserveThis, None)
 
+    @unittest.skipUnless(torch.backends.xnnpack.enabled,
+                         " XNNPACK must be enabled for these tests."
+                         " Please build with USE_XNNPACK=1.")
+    def test_quantized_conv_no_asan_failures(self):
+        # There were ASAN failures when fold_conv_bn was run on
+        # already quantized conv modules. Verifying that this does
+        # not happen again.
+
+        if 'qnnpack' not in torch.backends.quantized.supported_engines:
+            self.skipTest('qnnpack engine is not supported')
+
+        class Child(nn.Module):
+            def __init__(self):
+                super(Child, self).__init__()
+                self.conv2 = nn.Conv2d(1, 1, 1)
+
+            def forward(self, x):
+                x = self.conv2(x)
+                return x
+
+        class Parent(nn.Module):
+            def __init__(self):
+                super(Parent, self).__init__()
+                self.quant = torch.quantization.QuantStub()
+                self.conv1 = nn.Conv2d(1, 1, 1)
+                self.child = Child()
+                self.dequant = torch.quantization.DeQuantStub()
+
+            def forward(self, x):
+                x = self.quant(x)
+                x = self.conv1(x)
+                x = self.child(x)
+                x = self.dequant(x)
+                return x
+
+        with override_quantized_engine('qnnpack'):
+            model = Parent()
+            model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
+            torch.quantization.prepare(model, inplace=True)
+            model(torch.randn(4, 1, 4, 4))
+            torch.quantization.convert(model, inplace=True)
+            model = torch.jit.script(model)
+            # this line should not have ASAN failures
+            model_optim = optimize_for_mobile(model)
+            self.assertFalse(hasattr(model_optim.conv1, "bias"))
+            self.assertFalse(hasattr(model_optim.child.conv2, "bias"))
 
     def test_generate_mobile_module_lints(self):
         class MyTestModule(torch.nn.Module):
diff --git a/torch/csrc/jit/passes/fold_conv_bn.cpp b/torch/csrc/jit/passes/fold_conv_bn.cpp
index 211fd90..c6ffe39 100644
--- a/torch/csrc/jit/passes/fold_conv_bn.cpp
+++ b/torch/csrc/jit/passes/fold_conv_bn.cpp
@@ -59,13 +59,15 @@
 
 void addBiasForConvIfNone(Module& module, const std::string& pattern_name) {
   auto t = module.type()->expect<ClassType>();
-  auto real_typename = t->name()->qualifiedName();
-  if (real_typename.size() >= pattern_name.size() &&
-      (0 ==
-       real_typename.compare(
-           real_typename.size() - pattern_name.size(),
-           pattern_name.size(),
-           pattern_name))) {
+
+  const std::string real_typename = t->name()->qualifiedName();
+  const std::string demangled_typename = removeTorchMangle(real_typename);
+  bool is_floating_point_conv =
+      ((demangled_typename == "__torch__.torch.nn.modules.conv.Conv1d") ||
+       (demangled_typename == "__torch__.torch.nn.modules.conv.Conv2d") ||
+       (demangled_typename == "__torch__.torch.nn.modules.conv.Conv3d"));
+
+  if (is_floating_point_conv) {
     if (!t->hasAttribute("bias")) {
       auto optional_tensor_type = OptionalType::create(TensorType::get());
       t->addAttribute("bias", optional_tensor_type, true);
diff --git a/torch/csrc/jit/passes/quantization/helper.cpp b/torch/csrc/jit/passes/quantization/helper.cpp
index dcaf92a..0580f41 100644
--- a/torch/csrc/jit/passes/quantization/helper.cpp
+++ b/torch/csrc/jit/passes/quantization/helper.cpp
@@ -609,13 +609,16 @@
   return v->type()->cast<FunctionType>() && getFuncName(v) == functional;
 }
 
+std::string removeTorchMangle(const std::string& orig_name) {
+  static std::regex mangle_re("\\.___torch_mangle_\\d+");
+  auto qualified_name = std::regex_replace(orig_name, mangle_re, "");
+  return qualified_name;
+}
+
 c10::optional<std::string> getModuleName(Value* value) {
   auto type = value->type()->cast<ClassType>();
   if (type && type->name()) {
-    static std::regex mangle_re("\\.___torch_mangle_\\d+");
-    auto qualified_name =
-        std::regex_replace(type->name()->qualifiedName(), mangle_re, "");
-    return qualified_name;
+    return removeTorchMangle(type->name()->qualifiedName());
   }
   return c10::nullopt;
 }
diff --git a/torch/csrc/jit/passes/quantization/helper.h b/torch/csrc/jit/passes/quantization/helper.h
index ee2d3b5..7100d4c 100644
--- a/torch/csrc/jit/passes/quantization/helper.h
+++ b/torch/csrc/jit/passes/quantization/helper.h
@@ -44,6 +44,12 @@
 // Check if value is the input of the graph
 TORCH_API bool hitGraphInput(Value* value);
 
+// Converts a mangled name, such as
+//   __torch__.torch.nn.quantized.modules.conv.___torch_mangle_7.Conv2d
+// into an unmangled name, such as
+//   __torch__.torch.nn.quantized.modules.conv.Conv2d
+TORCH_API std::string removeTorchMangle(const std::string& orig_name);
+
 // Return the module name that corresponds to the value.
 TORCH_API c10::optional<std::string> getModuleName(Value* value);