Add prelu op and module for quantized CPU backend (#73491)
This PR adds a PReLU op and module for the quantized CPU backend. It includes:
- A quantized version of the `prelu` op (`quantized::prelu`)
- A native `prelu` kernel for the quantized CPU backend
- A `PReLU` module in `nn.quantized`
- FX graph mode quantization support for `PReLU`
- Unit tests for the op, the module, and the FX workflow
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73491
Approved by: https://github.com/jerryzh168
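
A minimal usage sketch of the new op and module (it mirrors the unit tests added below; the scales and zero points are illustrative only):

```python
import torch
import torch.nn.quantized as nnq

# Quantize a float NCHW input per-tensor to quint8.
x = torch.randn(4, 4, 4, 4)
qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)

# Quantized PReLU module with one learnable slope per channel (dim 1).
m = nnq.PReLU(output_scale=1.0, output_zero_point=0, num_parameters=4)
qw = torch.quantize_per_tensor(torch.randn(4), scale=1.0, zero_point=0, dtype=torch.quint8)
m.set_weight(qw)
qy = m(qx)

# The underlying op can also be called directly with explicit output qparams.
qy2 = torch.ops.quantized.prelu(qx, qw, 1.0, 0)
```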
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index f7dee89..022755d 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3987,6 +3987,7 @@
MkldnnCPU: mkldnn_prelu
CPU: prelu_cpu
CUDA: prelu_cuda
+ QuantizedCPU: prelu_quantized_cpu
- func: prelu_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor)
variants: function, method
diff --git a/aten/src/ATen/native/quantized/cpu/QuantizedOps.h b/aten/src/ATen/native/quantized/cpu/QuantizedOps.h
index bfa1f1f..0391723 100644
--- a/aten/src/ATen/native/quantized/cpu/QuantizedOps.h
+++ b/aten/src/ATen/native/quantized/cpu/QuantizedOps.h
@@ -163,6 +163,9 @@
double /* eps */,
Tensor* /* Y */);
+using qprelu_fn = void (*)(Tensor& /*out*/, const Tensor& /*qx*/,
+ const Tensor& /*qw*/);
+
DECLARE_DISPATCH(qadaptive_avg_pool2d_fn, qadaptive_avg_pool2d_nhwc_stub);
DECLARE_DISPATCH(qadaptive_avg_pool3d_fn, qadaptive_avg_pool3d_ndhwc_stub);
DECLARE_DISPATCH(qadd_scalar_fn, qadd_scalar_relu_stub);
@@ -194,6 +197,7 @@
DECLARE_DISPATCH(qthreshold_fn, qthreshold_stub);
DECLARE_DISPATCH(qtopk_fn, qtopk_stub);
DECLARE_DISPATCH(qupsample_bilinear2d_fn, qupsample_bilinear2d_nhwc_stub);
+DECLARE_DISPATCH(qprelu_fn, qprelu_stub);
} // namespace native
} // namespace at
diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
index d16e4d6..fc9d7a8 100644
--- a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
+++ b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -652,6 +652,81 @@
});
}
+static void qprelu_out_kernel(Tensor& out,
+ const Tensor& qx,
+ const Tensor& qw) {
+ int32_t i_zp = static_cast<int32_t>(qx.q_zero_point());
+ float i_scale = static_cast<float>(qx.q_scale());
+
+ int32_t w_zp = static_cast<int32_t>(qw.q_zero_point());
+ float w_scale = static_cast<float>(qw.q_scale());
+
+ int32_t o_zp = static_cast<int32_t>(out.q_zero_point());
+ float o_scale = static_cast<float>(out.q_scale());
+ float o_inv_scale = 1.0f / o_scale;
+
+ float multiplier = i_scale * w_scale * o_inv_scale;
+
+ int64_t input_ndim = qx.dim();
+ TORCH_CHECK(input_ndim > 0, "qprelu: zero-dim input tensor is not allowed.");
+
+ // Helper to convert 1d tensors or scalar tensor to an nd tensor that broadcasts with input
+ // All elements go into the channel dimension
+ DimVector sizes(input_ndim, 1), strides(input_ndim, 0);
+ auto as_nd = [&](const Tensor& t) {
+ TORCH_INTERNAL_ASSERT(t.defined() && (t.dim() == 1 || t.dim() == 0));
+ sizes[1] = t.dim() == 1 ? t.sizes()[0] : 1;
+ strides[1] = t.dim() == 1 ? t.strides()[0] : 0;
+ return t.as_strided(sizes, strides);
+ };
+
+ auto qw_nd = as_nd(qw);
+
+ auto iter = TensorIteratorConfig()
+ .add_output(out)
+ .add_input(qx)
+ .add_input(qw_nd)
+ .build();
+
+ AT_DISPATCH_QINT_TYPES(out.scalar_type(), "qprelu", [&] {
+ using qVec = Vectorized<scalar_t>;
+ qVec i_zp_vec = qVec(static_cast<scalar_t>(i_zp));
+ qVec w_zp_vec = qVec(static_cast<scalar_t>(w_zp));
+
+ // Quantized one as weight
+ auto qw_one = at::native::quantize_val<scalar_t>(w_scale, w_zp, 1.0f);
+ qVec vec_qw_one = qVec(qw_one);
+ auto vec_qw_one_sub_zp = vec_qw_one.widening_subtract(w_zp_vec)[0];
+ int32_t qw_one_sub_zp = qw_one.val_ - w_zp;
+
+ cpu_kernel_vec(
+ iter,
+ [=](scalar_t val_qx, scalar_t val_qw) -> scalar_t {
+ int32_t qx_pos = std::max(static_cast<int32_t>(val_qx.val_), i_zp);
+ int32_t qx_neg = std::min(static_cast<int32_t>(val_qx.val_), i_zp);
+ int32_t qx_pos_sub_zp = qx_pos - i_zp;
+ int32_t qx_neg_sub_zp = qx_neg - i_zp;
+ int32_t qw_sub_zp = val_qw.val_ - w_zp;
+ auto qy_sub_zp = qx_pos_sub_zp * qw_one_sub_zp + qx_neg_sub_zp * qw_sub_zp;
+ return at::native::requantize_from_int<scalar_t>(
+ multiplier, o_zp, qy_sub_zp);
+ },
+ [=](qVec vec_qx, qVec vec_qw) -> qVec {
+ auto vec_qx_pos = vec_qx.maximum(i_zp_vec);
+ auto vec_qx_neg = vec_qx.minimum(i_zp_vec);
+ qVec::int_vec_return_type qx_pos_sub_zp = vec_qx_pos.widening_subtract(i_zp_vec);
+ qVec::int_vec_return_type qx_neg_sub_zp = vec_qx_neg.widening_subtract(i_zp_vec);
+ qVec::int_vec_return_type qw_sub_zp = vec_qw.widening_subtract(w_zp_vec);
+ qVec::int_vec_return_type qy_sub_zp;
+ for (const auto i : c10::irange(qVec::int_num_vecs())) {
+ qy_sub_zp[i] = qx_pos_sub_zp[i] * vec_qw_one_sub_zp + qx_neg_sub_zp[i] * qw_sub_zp[i];
+ }
+ return qVec::requantize_from_int(qy_sub_zp, multiplier, o_zp);
+ });
+ });
+
+}
+
void qgelu_kernel(const Tensor& qx, Tensor& qy, GeluType approximate) {
int64_t zero_point = qx.q_zero_point();
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
@@ -3694,6 +3769,7 @@
REGISTER_NO_AVX512_DISPATCH(qmul_stub);
REGISTER_NO_AVX512_DISPATCH(qrelu_leaky_stub);
REGISTER_NO_AVX512_DISPATCH(qrelu_stub);
+REGISTER_NO_AVX512_DISPATCH(qprelu_stub);
REGISTER_NO_AVX512_DISPATCH(qgelu_stub);
REGISTER_NO_AVX512_DISPATCH(qsigmoid_stub);
REGISTER_NO_AVX512_DISPATCH(qtanh_stub);
@@ -3748,6 +3824,7 @@
REGISTER_DISPATCH(qmul_stub, &qmul_kernel<false>);
REGISTER_DISPATCH(qrelu_leaky_stub, &leaky_qrelu_out_kernel);
REGISTER_DISPATCH(qrelu_stub, &qrelu_kernel);
+REGISTER_DISPATCH(qprelu_stub, &qprelu_out_kernel);
REGISTER_DISPATCH(qgelu_stub, &qgelu_kernel);
REGISTER_DISPATCH(qsigmoid_stub, &qsigmoid_kernel);
REGISTER_DISPATCH(qtanh_stub, &qtanh_kernel);
diff --git a/aten/src/ATen/native/quantized/cpu/qrelu.cpp b/aten/src/ATen/native/quantized/cpu/qrelu.cpp
index ad071d6..e4ca887 100644
--- a/aten/src/ATen/native/quantized/cpu/qrelu.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qrelu.cpp
@@ -17,6 +17,7 @@
DEFINE_DISPATCH(qrelu_stub);
DEFINE_DISPATCH(qrelu_leaky_stub);
+DEFINE_DISPATCH(qprelu_stub);
#ifdef USE_PYTORCH_QNNPACK
Tensor qnnpack_relu(Tensor input) {
@@ -134,6 +135,32 @@
return self;
}
+Tensor prelu_quantized_cpu_impl(const Tensor& self, const Tensor& weight,
+ double output_scale, int64_t output_zero_point) {
+ auto ndim = self.dim();
+ // for ndim < 1 or > 5, go to reference path
+ if (ndim > 5 || ndim < 1) {
+ auto x = self.dequantize();
+ auto y = at::prelu(x, weight);
+ return at::quantize_per_tensor(y, output_scale, output_zero_point, c10::kQUInt8);
+ }
+
+ auto qy = at::_empty_affine_quantized(self.sizes(),
+ at::device(kCPU)
+ .dtype(self.scalar_type()),
+ output_scale,
+ output_zero_point,
+ self.suggest_memory_format());
+
+ qprelu_stub(self.device().type(), qy, self, weight);
+
+ return qy;
+}
+
+Tensor prelu_quantized_cpu(const Tensor& self, const Tensor& weight) {
+ return prelu_quantized_cpu_impl(self, weight, self.q_scale(), self.q_zero_point());
+}
+
namespace {
Tensor quantized_relu6(const Tensor& qx) {
Tensor qy;
@@ -175,9 +202,17 @@
}
};
+class QPRelu final {
+ public:
+ static Tensor run(Tensor self, const Tensor& weight, double output_scale, int64_t output_zero_point) {
+ return prelu_quantized_cpu_impl(self, weight, output_scale, output_zero_point);
+ }
+};
+
TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
m.impl(TORCH_SELECTIVE_NAME("quantized::relu6"), TORCH_FN(QRelu6::run));
m.impl(TORCH_SELECTIVE_NAME("quantized::leaky_relu"), TORCH_FN(QLeakyRelu::run));
+ m.impl(TORCH_SELECTIVE_NAME("quantized::prelu"), TORCH_FN(QPRelu::run));
}
} // namespace
diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp
index ea5338d..a6ac4b3 100644
--- a/aten/src/ATen/native/quantized/library.cpp
+++ b/aten/src/ATen/native/quantized/library.cpp
@@ -187,6 +187,7 @@
m.def(TORCH_SELECTIVE_SCHEMA("quantized::max_pool2d(Tensor qx, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::relu6(Tensor qx, bool inplace=False) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::leaky_relu(Tensor qx, Scalar negative_slope, bool inplace, float output_scale, int output_zero_point) -> Tensor"));
+ m.def(TORCH_SELECTIVE_SCHEMA("quantized::prelu(Tensor qx, Tensor weight, float output_scale, int output_zero_point) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::sigmoid(Tensor qx, float output_scale, int output_zero_point) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::softmax(Tensor qx, int dim, float output_scale, int output_zero_point) -> Tensor"));
}
diff --git a/test/quantization/core/test_quantized_module.py b/test/quantization/core/test_quantized_module.py
index 7cbab3b..067b7b4 100644
--- a/test/quantization/core/test_quantized_module.py
+++ b/test/quantization/core/test_quantized_module.py
@@ -1011,6 +1011,30 @@
self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices,
offsets, set_qconfig, is_emb_bag=True, dtype=qdtype)
+ def test_prelu(self):
+ x = torch.randn((4, 4, 4, 4), dtype=torch.float)
+ qx = torch.quantize_per_tensor(x, 1.0, 0, dtype=torch.quint8)
+
+ # num_parameters = 1
+ prelu_module = nnq.PReLU(output_scale=1.0, output_zero_point=0, num_parameters=1)
+ w = torch.randn(1, dtype=torch.float)
+ qw = torch.quantize_per_tensor(w, 1.0, 0, dtype=torch.quint8)
+ prelu_module.set_weight(qw)
+ qy = prelu_module(qx)
+ qy_ref = torch.prelu(qx, qw)
+
+ self.assertEqual(qy_ref, qy,
+ msg="PReLU module API failed")
+
+ # num_parameters = num_channels
+ prelu_module = nnq.PReLU(output_scale=1.0, output_zero_point=0, num_parameters=4)
+ w = torch.randn(4, dtype=torch.float)
+ qw = torch.quantize_per_tensor(w, 1.0, 0, dtype=torch.quint8)
+ prelu_module.set_weight(qw)
+ qy = prelu_module(qx)
+ qy_ref = torch.prelu(qx, qw)
+ self.assertEqual(qy_ref, qy,
+ msg="PReLU module API failed")
class TestDynamicQuantizedModule(QuantizationTestCase):
def _test_qconv_impl(self, q_mod, dq_mod, dim, dtype, bias):
diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py
index 51d7473..4d9842b 100644
--- a/test/quantization/core/test_quantized_op.py
+++ b/test/quantization/core/test_quantized_op.py
@@ -469,6 +469,40 @@
self.assertEqual(qY.dequantize(), qY_hat.dequantize(),
msg="F.gelu failed ({} vs {})".format(qY, qY_hat))
+ """Tests the correctness of the quantized::prelu op."""
+ def test_qprelu(self):
+ shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4))
+ num_params = (0, 1) # 0: num_parameter = num_channels
+ dtypes = (torch.quint8, torch.qint8)
+ memory_formats = (torch.channels_last, torch.contiguous_format)
+ test_cases = itertools.product(shapes, num_params, dtypes, memory_formats)
+ for shape, num_param, dtype, memory_format in test_cases:
+ if memory_format == torch.channels_last and len(shape) != 4:
+ continue
+ X, scale, zero_point, torch_type = \
+ torch.randn(*shape), 0.1, 0, dtype
+ X = X.to(memory_format=memory_format)
+ num_parameter = 1 if num_param == 1 or len(shape) == 1 else shape[1]
+ W = torch.randn(num_parameter)
+ W, w_scale, w_zero_point = \
+ torch.randn(num_parameter), 0.2, 0
+
+ qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
+ dtype=torch_type)
+ dqX = qX.dequantize()
+ qW = torch.quantize_per_tensor(W, scale=w_scale, zero_point=w_zero_point,
+ dtype=torch_type)
+ dqW = qW.dequantize()
+
+ op = torch.nn.functional.prelu
+ qop = torch.ops.quantized.prelu
+ dqY = op(dqX, dqW)
+ qY = torch.quantize_per_tensor(dqY, scale=scale, zero_point=zero_point,
+ dtype=torch_type)
+ qY_hat = qop(qX, qW, scale, zero_point)
+ self.assertEqual(qY.dequantize(), qY_hat.dequantize(),
+ msg="F.prelu failed ({} vs {})".format(qY, qY_hat))
+
"""Tests the correctness of the quantized::qlayer_norm op."""
@skipIfNoFBGEMM
def test_qlayer_norm(self):
diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py
index 9498f34..f086a57 100644
--- a/test/quantization/fx/test_quantize_fx.py
+++ b/test/quantization/fx/test_quantize_fx.py
@@ -5782,6 +5782,28 @@
def test_leaky_relu(self):
self._test_activation_impl(nn.LeakyReLU, F.leaky_relu, nnq.LeakyReLU, torch.ops.quantized.leaky_relu)
+ def test_prelu(self):
+ class M(torch.nn.Module):
+ def __init__(self, num_param: int):
+ super(M, self).__init__()
+ self.op = torch.nn.PReLU(num_parameters=num_param)
+
+ def forward(self, input):
+ return self.op(input)
+
+ X = [[torch.randn(4, 4, 4, 4, dtype=torch.float)]]
+ options = itertools.product([1, 4], self.static_quant_types, [True, False])
+ quantized_nodes = {
+ # is_reference
+ True: ns.call_module(torch.nn.PReLU),
+ False: ns.call_module(torch.nn.quantized.PReLU),
+ }
+
+ for num_parameter, quant_type, is_reference in options:
+ self.checkGraphModeFxOp(
+ M(num_parameter), X, quant_type, quantized_nodes[is_reference],
+ is_reference=is_reference)
+
def _test_norm_impl(
self, float_module, float_op, op_args, data, quantized_module, quantized_op,
skip_op_arg_for_functional=False):
diff --git a/test/test_module_init.py b/test/test_module_init.py
index d7d139b..61a9a2f 100644
--- a/test/test_module_init.py
+++ b/test/test_module_init.py
@@ -226,6 +226,7 @@
'factory_kwargs': {},
}),
torch.nn.quantized.MaxPool2d: ((3,), {}),
+ torch.nn.quantized.PReLU: ((0.01, 0), {}),
torch.nn.quantized.Quantize: ((0.1, 0), {
'dtype': torch.int16,
'factory_kwargs': {},
diff --git a/torch/ao/ns/fx/mappings.py b/torch/ao/ns/fx/mappings.py
index fc53a24..7bcd0ba 100644
--- a/torch/ao/ns/fx/mappings.py
+++ b/torch/ao/ns/fx/mappings.py
@@ -310,6 +310,16 @@
set([
nn.Softmax,
]),
+ # PReLU
+ set([
+ nn.PReLU,
+ nnq.PReLU,
+ ]),
+ # F.prelu
+ set([
+ F.prelu,
+ toq.prelu,
+ ]),
]
# for each floating point op, add versions of the op added by
@@ -468,6 +478,7 @@
operator.mul,
torch.mul,
torch.sum,
+ F.prelu,
])
FUNS_IO_TYPE_FP16: Set[NSNodeTargetType] = set()
@@ -488,6 +499,7 @@
toq.layer_norm,
toq.leaky_relu,
toq.dropout,
+ toq.prelu,
# TODO(future PR): implement shadowing for binary ops and
# uncomment below
# toq.add,
@@ -568,6 +580,7 @@
nn.SiLU,
nn.Mish,
nn.Softmax,
+ nn.PReLU,
nni.BNReLU2d,
nni.BNReLU3d,
nni.ConvReLU1d,
@@ -613,6 +626,7 @@
nnq.EmbeddingBag,
nnq.Dropout,
nnq.Softmax,
+ nnq.PReLU,
nniq.BNReLU2d,
nniq.BNReLU3d,
nniq.ConvReLU1d,
diff --git a/torch/ao/quantization/backend_config/native.py b/torch/ao/quantization/backend_config/native.py
index 8b8e764..bf259b8 100644
--- a/torch/ao/quantization/backend_config/native.py
+++ b/torch/ao/quantization/backend_config/native.py
@@ -106,6 +106,7 @@
torch.nn.InstanceNorm3d,
torch.nn.LayerNorm,
torch.nn.Dropout,
+ torch.nn.PReLU,
torch.nn.functional.elu,
torch.nn.functional.hardswish,
torch.nn.functional.instance_norm,
diff --git a/torch/ao/quantization/fx/_lower_to_native_backend.py b/torch/ao/quantization/fx/_lower_to_native_backend.py
index 41fbb36..02fdc76 100644
--- a/torch/ao/quantization/fx/_lower_to_native_backend.py
+++ b/torch/ao/quantization/fx/_lower_to_native_backend.py
@@ -83,6 +83,7 @@
torch.nn.InstanceNorm3d,
torch.nn.LayerNorm,
torch.nn.Dropout,
+ torch.nn.PReLU,
torch.nn.BatchNorm2d,
torch.nn.BatchNorm3d,
torch.nn.intrinsic.BNReLU2d,
@@ -238,6 +239,7 @@
nn.LayerNorm: nnq.LayerNorm,
nn.Dropout: nnq.Dropout,
nn.Softmax: nnq.Softmax,
+ nn.PReLU: nnq.PReLU,
nni.BNReLU2d: nniq.BNReLU2d,
nni.BNReLU3d: nniq.BNReLU3d,
}
diff --git a/torch/ao/quantization/quantization_mappings.py b/torch/ao/quantization/quantization_mappings.py
index ebaa693..8b192b6 100644
--- a/torch/ao/quantization/quantization_mappings.py
+++ b/torch/ao/quantization/quantization_mappings.py
@@ -71,6 +71,7 @@
nn.Linear: nnq.Linear,
nn.ReLU6: nnq.ReLU6,
nn.Dropout: nnq.Dropout,
+ nn.PReLU: nnq.PReLU,
# Wrapper Modules:
nnq.FloatFunctional: nnq.QFunctional,
# Intrinsic modules:
diff --git a/torch/nn/quantized/modules/__init__.py b/torch/nn/quantized/modules/__init__.py
index 62cea10..2ccfe1d 100644
--- a/torch/nn/quantized/modules/__init__.py
+++ b/torch/nn/quantized/modules/__init__.py
@@ -1,7 +1,7 @@
import torch
from torch.nn.modules.pooling import MaxPool2d
-from .activation import ReLU6, Hardswish, ELU, LeakyReLU, Sigmoid, Softmax, MultiheadAttention
+from .activation import ReLU6, Hardswish, ELU, LeakyReLU, Sigmoid, Softmax, MultiheadAttention, PReLU
from .dropout import Dropout
from .batchnorm import BatchNorm2d, BatchNorm3d
from .normalization import LayerNorm, GroupNorm, InstanceNorm1d, \
@@ -118,6 +118,7 @@
'Sigmoid',
'Softmax',
'Dropout',
+ 'PReLU',
# Wrapper modules
'FloatFunctional',
'FXFloatFunctional',
diff --git a/torch/nn/quantized/modules/activation.py b/torch/nn/quantized/modules/activation.py
index beaae33..d1ce62b 100644
--- a/torch/nn/quantized/modules/activation.py
+++ b/torch/nn/quantized/modules/activation.py
@@ -225,3 +225,53 @@
setattr(converted, 'bias_v', bias_v) # noqa: B010
return converted
+
+class PReLU(torch.nn.Module):
+ r"""This is the quantized equivalent of :class:`~torch.nn.PReLU`.
+
+ Args:
+ scale: quantization scale of the output tensor
+ zero_point: quantization zero point of the output tensor
+ num_parameters: number of parameters: 1, or the number of channels at input. Default: 1
+ """
+ def __init__(self, output_scale: float, output_zero_point: int,
+ num_parameters: int = 1) -> None:
+ super().__init__()
+ self.num_parameters = num_parameters
+ self.scale = output_scale
+ self.zero_point = output_zero_point
+ w = torch.randn(num_parameters, dtype=torch.float)
+ qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.quint8)
+ self.set_weight(qw)
+
+ def set_weight(self, w: torch.Tensor) -> None:
+ self.weight = w
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return torch.ops.quantized.prelu(input, self.weight, self.scale, self.zero_point)
+
+ def _get_name(self):
+ return 'QuantizedPReLU'
+
+ @classmethod
+ def from_float(cls, mod):
+ scale, zero_point = mod.activation_post_process.calculate_qparams()
+ qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
+ float_wt = mod.weight.float()
+ observer = mod.qconfig.weight()
+ wt_scale, wt_zp = observer.calculate_qparams()
+ qweight = torch.quantize_per_tensor(
+ float_wt, float(wt_scale), int(wt_zp), torch.quint8)
+ qprelu.set_weight(qweight)
+ return qprelu
+
+ @classmethod
+ def from_reference(cls, mod, scale, zero_point):
+ qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
+ float_wt = mod.weight.float()
+ observer = mod.qconfig.weight()
+ wt_scale, wt_zp = observer.calculate_qparams()
+ qweight = torch.quantize_per_tensor(
+ float_wt, float(wt_scale), int(wt_zp), torch.quint8)
+ qprelu.set_weight(qweight)
+ return qprelu
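
For reference, the per-element arithmetic implemented by `qprelu_out_kernel` above can be sketched in plain Python. This is a simplified, hypothetical scalar reference (clamping to the quantized range and the kernel's exact rounding mode are omitted); the variable names loosely mirror the C++ code:

```python
def qprelu_ref_elem(q_x, q_w, i_scale, i_zp, w_scale, w_zp, o_scale, o_zp):
    """Scalar sketch of quantized PReLU for one quint8 element and its channel weight."""
    multiplier = i_scale * w_scale / o_scale
    # Split the input around its zero point: values above i_zp keep a slope of 1,
    # values below i_zp are scaled by the (quantized) channel weight.
    qx_pos_sub_zp = max(q_x, i_zp) - i_zp
    qx_neg_sub_zp = min(q_x, i_zp) - i_zp
    # 1.0 quantized with the weight's qparams, expressed relative to its zero point.
    q_one_sub_zp = round(1.0 / w_scale)
    qw_sub_zp = q_w - w_zp
    # Integer accumulation followed by requantization into the output's qparams.
    acc = qx_pos_sub_zp * q_one_sub_zp + qx_neg_sub_zp * qw_sub_zp
    return round(acc * multiplier) + o_zp
```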