[quant] dequantize support list and tuple of tensors (#41079)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/41079

Test Plan: Imported from OSS

Differential Revision: D22420700

fbshipit-source-id: bc4bf0fb47dcf8b94b11fbdc91e8d5a75142b7be
diff --git a/test/quantization/test_quantize_jit.py b/test/quantization/test_quantize_jit.py
index 17642d0..ba4b68a 100644
--- a/test/quantization/test_quantize_jit.py
+++ b/test/quantization/test_quantize_jit.py
@@ -2221,6 +2221,25 @@
                        .run(m.graph)
 
     @skipIfNoFBGEMM
+    def test_dequantize_tuple(self):
+        """ Make sure dequantize can support Tuple of tensor
+        """
+        class M(torch.nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.conv1 = torch.nn.Conv2d(3, 3, 3).float()
+                self.conv2 = torch.nn.Conv2d(3, 3, 3).float()
+
+            def forward(self, x):
+                # type: (Tensor) -> Tuple[Tensor, Tensor]
+                x1 = self.conv1(x)
+                x2 = self.conv2(x)
+                return x1, x2
+
+        for tracing in [True, False]:
+            self.checkGraphModeOp(M(), self.img_data_2d, "quantized::conv2d", tracing)
+
+    @skipIfNoFBGEMM
     def test_clamp(self):
         class M(torch.nn.Module):
             def __init__(self):
diff --git a/torch/csrc/jit/passes/quantization/helper.cpp b/torch/csrc/jit/passes/quantization/helper.cpp
index 75c852e..f5fe731 100644
--- a/torch/csrc/jit/passes/quantization/helper.cpp
+++ b/torch/csrc/jit/passes/quantization/helper.cpp
@@ -301,14 +301,28 @@
     }
     return inputs;
   } else if (n->kind() == prim::ListUnpack || n->kind() == prim::TupleUnpack) {
-    return {n->input(0)};
+    // only propagate dequantize for Tensor
+    if (v->type()->isSubtypeOf(TensorType::get())) {
+      return {n->input(0)};
+    } else {
+      return {};
+    }
   } else if (
-      n->kind() == prim::ListConstruct || n->kind() == prim::TupleConstruct) {
+      n->kind() == prim::ListConstruct &&
+      v->type()->isSubtypeOf(ListType::ofTensors())) {
     std::vector<Value*> inputs;
     for (auto* v : n->inputs()) {
       inputs.push_back(v);
     }
     return inputs;
+  } else if (n->kind() == prim::TupleConstruct) {
+    std::vector<Value*> inputs;
+    for (auto* input : n->inputs()) {
+      if (input->type()->isSubtypeOf(TensorType::get())) {
+        inputs.push_back(input);
+      }
+    }
+    return inputs;
   } else if (isListAdd(n)) {
     // We need to propagate dequantize of n->input(0) if it is
     // not an empty list
diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp
index e1179c6..138fa48 100644
--- a/torch/csrc/jit/runtime/register_prim_ops.cpp
+++ b/torch/csrc/jit/runtime/register_prim_ops.cpp
@@ -424,6 +424,18 @@
            handler(ss.str());
          },
          aliasAnalysisSpecialCase()),
+     Operator(
+         "aten::dequantize.tensor(Tensor qtensor) -> Tensor",
+         [](Stack* stack) {
+           at::Tensor qtensor;
+           pop(stack, qtensor);
+           push(stack, at::dequantize(qtensor));
+         },
+         aliasAnalysisFromSchema()),
+     Operator(
+         "aten::dequantize.any(Any tensors) -> Any",
+         [](Stack* stack) { dequantize(*stack); },
+         aliasAnalysisFromSchema()),
      DEFINE_STRING_OP(aten::add, a + b, str),
      DEFINE_COMPARISON_OP(aten::eq, a == b),
      DEFINE_COMPARISON_OP(aten::ne, a != b),
diff --git a/torch/csrc/jit/runtime/vararg_functions.cpp b/torch/csrc/jit/runtime/vararg_functions.cpp
index 91bde98..2654e5f 100644
--- a/torch/csrc/jit/runtime/vararg_functions.cpp
+++ b/torch/csrc/jit/runtime/vararg_functions.cpp
@@ -1,4 +1,5 @@
 #include <torch/csrc/jit/runtime/vararg_functions.h>
+#include <ATen/ATen.h>
 
 namespace torch {
 namespace jit {
@@ -124,5 +125,36 @@
   push(stack, c10::ivalue::Tuple::create(std::move(output_elems)));
 }
 
+void dequantize(Stack& stack) {
+  auto iv = pop(stack);
+  if (iv.isTuple()) {
+    auto tuple = iv.toTuple();
+    auto elems = tuple->elements();
+    std::vector<IValue> output_elems;
+    output_elems.reserve(elems.size());
+    for (size_t i = 0; i < elems.size(); ++i) {
+      if (elems[i].isTensor()) {
+        output_elems.emplace_back(at::dequantize(elems[i].toTensor()));
+      } else {
+        output_elems.emplace_back(elems[i]);
+      }
+    }
+    push(stack, c10::ivalue::Tuple::create(std::move(output_elems)));
+  } else if (iv.isTensorList()) {
+    auto elems = iv.toTensorList();
+    auto output_list = c10::impl::GenericList(elems.elementType());
+    for (size_t i = 0; i < elems.size(); ++i) {
+      output_list.emplace_back(at::dequantize(elems[i]));
+    }
+    push(stack, std::move(output_list));
+  } else {
+    TORCH_CHECK(
+        false,
+        "Unsupported type in dequantize, only List[Tensor] and \
+ Tuple[Tensor or other types] are supported, got type:",
+        toString(iv.type()));
+  }
+}
+
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/runtime/vararg_functions.h b/torch/csrc/jit/runtime/vararg_functions.h
index fa3820a..4941172 100644
--- a/torch/csrc/jit/runtime/vararg_functions.h
+++ b/torch/csrc/jit/runtime/vararg_functions.h
@@ -30,5 +30,7 @@
 
 void tupleSlice(Stack& stack, size_t begin, size_t end);
 
+void dequantize(Stack& stack);
+
 } // namespace jit
 } // namespace torch