[quant] Support dequantize for lists and tuples of tensors (#41079)
Summary: Extend graph-mode quantization and the JIT runtime so that dequantize also handles lists and tuples of tensors: the dequantize-propagation pass now looks through ListConstruct/TupleConstruct (and their unpack counterparts) for Tensor values, and new aten::dequantize.tensor / aten::dequantize.any operators are registered to perform the dequantization at runtime.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/41079
Test Plan: Imported from OSS
Differential Revision: D22420700
fbshipit-source-id: bc4bf0fb47dcf8b94b11fbdc91e8d5a75142b7be
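
What this enables, as a minimal sketch (assumes an FBGEMM build; the input
shape, qconfig choice, and calibration loop are illustrative, not part of
the PR): a scripted module that returns a tuple of tensors can now go
through graph-mode quantization end to end.

    import torch
    from torch.quantization import get_default_qconfig, quantize_jit

    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3).float()
            self.conv2 = torch.nn.Conv2d(3, 3, 3).float()

        def forward(self, x):
            # type: (Tensor) -> Tuple[Tensor, Tensor]
            return self.conv1(x), self.conv2(x)

    def calibrate(model, data):
        for x in data:
            model(x)

    data = [torch.randn(1, 3, 10, 10)]
    model = torch.jit.script(M()).eval()
    qmodel = quantize_jit(model, {'': get_default_qconfig('fbgemm')},
                          calibrate, [data])
    # Before this change the tuple return blocked dequantize propagation;
    # now both conv outputs are rewritten to quantized::conv2d.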
diff --git a/test/quantization/test_quantize_jit.py b/test/quantization/test_quantize_jit.py
index 17642d0..ba4b68a 100644
--- a/test/quantization/test_quantize_jit.py
+++ b/test/quantization/test_quantize_jit.py
@@ -2221,6 +2221,25 @@
.run(m.graph)
@skipIfNoFBGEMM
+ def test_dequantize_tuple(self):
+ """ Make sure dequantize can support Tuple of tensor
+ """
+ class M(torch.nn.Module):
+ def __init__(self):
+ super(M, self).__init__()
+ self.conv1 = torch.nn.Conv2d(3, 3, 3).float()
+ self.conv2 = torch.nn.Conv2d(3, 3, 3).float()
+
+ def forward(self, x):
+ # type: (Tensor) -> Tuple[Tensor, Tensor]
+ x1 = self.conv1(x)
+ x2 = self.conv2(x)
+ return x1, x2
+
+ for tracing in [True, False]:
+ self.checkGraphModeOp(M(), self.img_data_2d, "quantized::conv2d", tracing)
+
+ @skipIfNoFBGEMM
def test_clamp(self):
class M(torch.nn.Module):
def __init__(self):
diff --git a/torch/csrc/jit/passes/quantization/helper.cpp b/torch/csrc/jit/passes/quantization/helper.cpp
index 75c852e..f5fe731 100644
--- a/torch/csrc/jit/passes/quantization/helper.cpp
+++ b/torch/csrc/jit/passes/quantization/helper.cpp
@@ -301,14 +301,28 @@
}
return inputs;
} else if (n->kind() == prim::ListUnpack || n->kind() == prim::TupleUnpack) {
- return {n->input(0)};
+ // Only propagate dequantize for Tensor values
+ if (v->type()->isSubtypeOf(TensorType::get())) {
+ return {n->input(0)};
+ } else {
+ return {};
+ }
} else if (
- n->kind() == prim::ListConstruct || n->kind() == prim::TupleConstruct) {
+ n->kind() == prim::ListConstruct &&
+ v->type()->isSubtypeOf(ListType::ofTensors())) {
std::vector<Value*> inputs;
for (auto* v : n->inputs()) {
inputs.push_back(v);
}
return inputs;
+ } else if (n->kind() == prim::TupleConstruct) {
+ std::vector<Value*> inputs;
+ for (auto* input : n->inputs()) {
+ if (input->type()->isSubtypeOf(TensorType::get())) {
+ inputs.push_back(input);
+ }
+ }
+ return inputs;
} else if (isListAdd(n)) {
// We need to propagate dequantize of n->input(0) if it is
// not an empty list
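
The change above narrows getPassThroughInputs: dequantize now propagates
through ListUnpack/TupleUnpack only when the unpacked value is a Tensor,
through ListConstruct only for List[Tensor], and through TupleConstruct
only for its Tensor elements, so non-Tensor tuple members are left alone.
For illustration (a hypothetical module, not from the PR's tests):

    import torch

    class Mixed(torch.nn.Module):
        def __init__(self):
            super(Mixed, self).__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3).float()

        def forward(self, x):
            # type: (Tensor) -> Tuple[Tensor, int]
            # Only the Tensor element takes part in dequantize propagation;
            # the int passes through the TupleConstruct unchanged.
            return self.conv(x), x.size(0)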
diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp
index e1179c6..138fa48 100644
--- a/torch/csrc/jit/runtime/register_prim_ops.cpp
+++ b/torch/csrc/jit/runtime/register_prim_ops.cpp
@@ -424,6 +424,18 @@
handler(ss.str());
},
aliasAnalysisSpecialCase()),
+ Operator(
+ "aten::dequantize.tensor(Tensor qtensor) -> Tensor",
+ [](Stack* stack) {
+ at::Tensor qtensor;
+ pop(stack, qtensor);
+ push(stack, at::dequantize(qtensor));
+ },
+ aliasAnalysisFromSchema()),
+ Operator(
+ "aten::dequantize.any(Any tensors) -> Any",
+ [](Stack* stack) { dequantize(*stack); },
+ aliasAnalysisFromSchema()),
DEFINE_STRING_OP(aten::add, a + b, str),
DEFINE_COMPARISON_OP(aten::eq, a == b),
DEFINE_COMPARISON_OP(aten::ne, a != b),
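
The two operators above give rewritten graphs something to execute:
aten::dequantize.tensor is the plain single-tensor overload, while
aten::dequantize.any takes an Any-typed input so a single schema covers
both tuples and tensor lists, deferring to the new dequantize helper in
vararg_functions.cpp. A quick registration check (a sketch only;
_jit_get_schemas_for_operator is an internal API):

    import torch

    schemas = torch._C._jit_get_schemas_for_operator('aten::dequantize')
    print([str(s) for s in schemas])  # expect the .tensor and .any overloads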
diff --git a/torch/csrc/jit/runtime/vararg_functions.cpp b/torch/csrc/jit/runtime/vararg_functions.cpp
index 91bde98..2654e5f 100644
--- a/torch/csrc/jit/runtime/vararg_functions.cpp
+++ b/torch/csrc/jit/runtime/vararg_functions.cpp
@@ -1,4 +1,5 @@
#include <torch/csrc/jit/runtime/vararg_functions.h>
+#include <ATen/ATen.h>
namespace torch {
namespace jit {
@@ -124,5 +125,36 @@
push(stack, c10::ivalue::Tuple::create(std::move(output_elems)));
}
+void dequantize(Stack& stack) {
+ auto iv = pop(stack);
+ if (iv.isTuple()) {
+ auto tuple = iv.toTuple();
+ auto elems = tuple->elements();
+ std::vector<IValue> output_elems;
+ output_elems.reserve(elems.size());
+ for (size_t i = 0; i < elems.size(); ++i) {
+ if (elems[i].isTensor()) {
+ output_elems.emplace_back(at::dequantize(elems[i].toTensor()));
+ } else {
+ output_elems.emplace_back(elems[i]);
+ }
+ }
+ push(stack, c10::ivalue::Tuple::create(std::move(output_elems)));
+ } else if (iv.isTensorList()) {
+ auto elems = iv.toTensorList();
+ auto output_list = c10::impl::GenericList(elems.elementType());
+ for (size_t i = 0; i < elems.size(); ++i) {
+ output_list.emplace_back(at::dequantize(elems[i]));
+ }
+ push(stack, std::move(output_list));
+ } else {
+ TORCH_CHECK(
+ false,
+ "Unsupported type in dequantize, only List[Tensor] and \
+ Tuple[Tensor or other types] are supported, got type:",
+ toString(iv.type()));
+ }
+}
+
} // namespace jit
} // namespace torch
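
In Python terms, the new helper behaves roughly as below (an illustrative
rendering only; the real function pops an IValue off the interpreter stack
and pushes the result back):

    import torch

    def dequantize_value(value):
        # Rough equivalent of the Tuple and List[Tensor] branches.
        if isinstance(value, tuple):
            return tuple(torch.dequantize(e) if isinstance(e, torch.Tensor)
                         else e for e in value)
        if isinstance(value, list):
            return [torch.dequantize(e) for e in value]
        raise TypeError('Unsupported type in dequantize: %s' % type(value))

    q = torch.quantize_per_tensor(torch.rand(2, 2), scale=0.1,
                                  zero_point=0, dtype=torch.quint8)
    out = dequantize_value((q, 3))
    assert not out[0].is_quantized and out[1] == 3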
diff --git a/torch/csrc/jit/runtime/vararg_functions.h b/torch/csrc/jit/runtime/vararg_functions.h
index fa3820a..4941172 100644
--- a/torch/csrc/jit/runtime/vararg_functions.h
+++ b/torch/csrc/jit/runtime/vararg_functions.h
@@ -30,5 +30,7 @@
void tupleSlice(Stack& stack, size_t begin, size_t end);
+void dequantize(Stack& stack);
+
} // namespace jit
} // namespace torch