Add type inference for dequantize.tensors (#49517)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49517
We should add concrete type info for the Tensor List case as well.
Test Plan: ci
Reviewed By: qizzzh
Differential Revision: D25599223
fbshipit-source-id: 3614e9ec25fc963a8d6a0bd641735fcca6c87032
diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp
index d9bffa7..f23b09d 100644
--- a/torch/csrc/jit/runtime/register_prim_ops.cpp
+++ b/torch/csrc/jit/runtime/register_prim_ops.cpp
@@ -679,6 +679,12 @@
push(stack, x != y);
},
aliasAnalysisFromSchema()),
+ // We define aten::dequantize in both native_functions.yaml and here,
+ // however, aten::dequantize.any defined here overrides
+ // aten::dequantize.tensors in native_functions.yaml. The variants here
+ // are only for graph mode quantization, and they should be removed once
+ // we deprecate graph mode quantization, and use the variants in
+ // native_functions.yaml.
OperatorGenerator(
TORCH_SELECTIVE_SCHEMA(
"aten::dequantize.tensor(Tensor qtensor) -> Tensor"),
@@ -689,6 +695,14 @@
},
aliasAnalysisFromSchema()),
OperatorGenerator(
+ TORCH_SELECTIVE_SCHEMA(
+ "aten::dequantize.list(Tensor[] qtensors) -> Tensor[]"),
+ [](Stack* stack) {
+ auto qtensors = pop(stack).toTensorVector();
+ push(stack, at::dequantize(qtensors));
+ },
+ aliasAnalysisFromSchema()),
+ OperatorGenerator(
TORCH_SELECTIVE_SCHEMA("aten::dequantize.any(Any tensors) -> Any"),
[](Stack* stack) { dequantize(*stack); },
aliasAnalysisFromSchema()),