Add prelu op and module for quantized CPU backend (#73491)

Add a PReLU op and module for the quantized CPU backend.
The PR includes:
- A quantized version of the PReLU op (`quantized::prelu`)
- A native PReLU kernel for the quantized CPU backend
- PReLU modules in `nn` and `nn.quantized`
- FX graph mode quantization support for PReLU
- Unit tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73491
Approved by: https://github.com/jerryzh168
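
For context, a minimal usage sketch of the new op and module follows. It mirrors the unit tests added in this PR and assumes a PyTorch build that already contains these changes; the concrete scales, zero points, and shapes are illustrative only.

```python
import torch
import torch.nn.quantized as nnq

x = torch.randn(4, 4, 4, 4)  # NCHW float input
w = torch.randn(4)           # one slope per channel
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)
qw = torch.quantize_per_tensor(w, scale=0.2, zero_point=0, dtype=torch.quint8)

# Op-level API added by this PR: output scale/zero_point are passed explicitly.
qy_op = torch.ops.quantized.prelu(qx, qw, 0.1, 0)

# Module-level API: output qparams are fixed at construction time and the
# quantized weight is attached via set_weight().
m = nnq.PReLU(output_scale=0.1, output_zero_point=0, num_parameters=4)
m.set_weight(qw)
qy_mod = m(qx)

print(qy_op.dequantize().shape, qy_mod.dequantize().shape)
```

In eager mode workflows, `nnq.PReLU.from_float` converts an observed `nn.PReLU` by quantizing its weight with the module's weight observer, as implemented in `torch/nn/quantized/modules/activation.py` below.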
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index f7dee89..022755d 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3987,6 +3987,7 @@
     MkldnnCPU: mkldnn_prelu
     CPU: prelu_cpu
     CUDA: prelu_cuda
+    QuantizedCPU: prelu_quantized_cpu
 
 - func: prelu_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor)
   variants: function, method
diff --git a/aten/src/ATen/native/quantized/cpu/QuantizedOps.h b/aten/src/ATen/native/quantized/cpu/QuantizedOps.h
index bfa1f1f..0391723 100644
--- a/aten/src/ATen/native/quantized/cpu/QuantizedOps.h
+++ b/aten/src/ATen/native/quantized/cpu/QuantizedOps.h
@@ -163,6 +163,9 @@
     double /* eps */,
     Tensor* /* Y */);
 
+using qprelu_fn = void (*)(Tensor& /*out*/, const Tensor& /*qx*/,
+                           const Tensor& /*qw*/);
+
 DECLARE_DISPATCH(qadaptive_avg_pool2d_fn, qadaptive_avg_pool2d_nhwc_stub);
 DECLARE_DISPATCH(qadaptive_avg_pool3d_fn, qadaptive_avg_pool3d_ndhwc_stub);
 DECLARE_DISPATCH(qadd_scalar_fn, qadd_scalar_relu_stub);
@@ -194,6 +197,7 @@
 DECLARE_DISPATCH(qthreshold_fn, qthreshold_stub);
 DECLARE_DISPATCH(qtopk_fn, qtopk_stub);
 DECLARE_DISPATCH(qupsample_bilinear2d_fn, qupsample_bilinear2d_nhwc_stub);
+DECLARE_DISPATCH(qprelu_fn, qprelu_stub);
 
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
index d16e4d6..fc9d7a8 100644
--- a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
+++ b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -652,6 +652,81 @@
   });
 }
 
+static void qprelu_out_kernel(Tensor& out,
+                              const Tensor& qx,
+                              const Tensor& qw) {
+  int32_t i_zp = static_cast<int32_t>(qx.q_zero_point());
+  float i_scale = static_cast<float>(qx.q_scale());
+
+  int32_t w_zp = static_cast<int32_t>(qw.q_zero_point());
+  float w_scale = static_cast<float>(qw.q_scale());
+
+  int32_t o_zp = static_cast<int32_t>(out.q_zero_point());
+  float o_scale = static_cast<float>(out.q_scale());
+  float o_inv_scale = 1.0f / o_scale;
+
+  float multiplier = i_scale * w_scale * o_inv_scale;
+
+  int64_t input_ndim = qx.dim();
+  TORCH_CHECK(input_ndim > 0, "qprelu: zero-dim input tensor is not allowed.");
+
+  // Helper to convert a 1-d or 0-d (scalar) tensor into an n-d tensor that broadcasts with the input.
+  // All of its elements are placed along the channel dimension (dim 1).
+  DimVector sizes(input_ndim, 1), strides(input_ndim, 0);
+  auto as_nd = [&](const Tensor& t) {
+    TORCH_INTERNAL_ASSERT(t.defined() && (t.dim() == 1 || t.dim() == 0));
+    sizes[1] = t.dim() == 1 ? t.sizes()[0] : 1;
+    strides[1] = t.dim() == 1 ? t.strides()[0] : 0;
+    return t.as_strided(sizes, strides);
+  };
+
+  auto qw_nd = as_nd(qw);
+
+  auto iter = TensorIteratorConfig()
+    .add_output(out)
+    .add_input(qx)
+    .add_input(qw_nd)
+    .build();
+
+  AT_DISPATCH_QINT_TYPES(out.scalar_type(), "qprelu", [&] {
+    using qVec = Vectorized<scalar_t>;
+    qVec i_zp_vec = qVec(static_cast<scalar_t>(i_zp));
+    qVec w_zp_vec = qVec(static_cast<scalar_t>(w_zp));
+
+    // The value 1.0 quantized with the weight's qparams; used as the slope for the positive part
+    auto qw_one = at::native::quantize_val<scalar_t>(w_scale, w_zp, 1.0f);
+    qVec vec_qw_one = qVec(qw_one);
+    auto vec_qw_one_sub_zp = vec_qw_one.widening_subtract(w_zp_vec)[0];
+    int32_t qw_one_sub_zp = qw_one.val_ - w_zp;
+
+    cpu_kernel_vec(
+      iter,
+      [=](scalar_t val_qx, scalar_t val_qw) -> scalar_t {
+        int32_t qx_pos = std::max(static_cast<int32_t>(val_qx.val_), i_zp);
+        int32_t qx_neg = std::min(static_cast<int32_t>(val_qx.val_), i_zp);
+        int32_t qx_pos_sub_zp = qx_pos - i_zp;
+        int32_t qx_neg_sub_zp = qx_neg - i_zp;
+        int32_t qw_sub_zp = val_qw.val_ - w_zp;
+        auto qy_sub_zp = qx_pos_sub_zp * qw_one_sub_zp + qx_neg_sub_zp * qw_sub_zp;
+        return at::native::requantize_from_int<scalar_t>(
+            multiplier, o_zp, qy_sub_zp);
+      },
+      [=](qVec vec_qx, qVec vec_qw) -> qVec {
+        auto vec_qx_pos = vec_qx.maximum(i_zp_vec);
+        auto vec_qx_neg = vec_qx.minimum(i_zp_vec);
+        qVec::int_vec_return_type qx_pos_sub_zp = vec_qx_pos.widening_subtract(i_zp_vec);
+        qVec::int_vec_return_type qx_neg_sub_zp = vec_qx_neg.widening_subtract(i_zp_vec);
+        qVec::int_vec_return_type qw_sub_zp = vec_qw.widening_subtract(w_zp_vec);
+        qVec::int_vec_return_type qy_sub_zp;
+        for (const auto i : c10::irange(qVec::int_num_vecs())) {
+          qy_sub_zp[i] = qx_pos_sub_zp[i] * vec_qw_one_sub_zp + qx_neg_sub_zp[i] * qw_sub_zp[i];
+        }
+        return qVec::requantize_from_int(qy_sub_zp, multiplier, o_zp);
+      });
+  });
+
+}
+
 void qgelu_kernel(const Tensor& qx, Tensor& qy, GeluType approximate) {
   int64_t zero_point = qx.q_zero_point();
   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
@@ -3694,6 +3769,7 @@
 REGISTER_NO_AVX512_DISPATCH(qmul_stub);
 REGISTER_NO_AVX512_DISPATCH(qrelu_leaky_stub);
 REGISTER_NO_AVX512_DISPATCH(qrelu_stub);
+REGISTER_NO_AVX512_DISPATCH(qprelu_stub);
 REGISTER_NO_AVX512_DISPATCH(qgelu_stub);
 REGISTER_NO_AVX512_DISPATCH(qsigmoid_stub);
 REGISTER_NO_AVX512_DISPATCH(qtanh_stub);
@@ -3748,6 +3824,7 @@
 REGISTER_DISPATCH(qmul_stub, &qmul_kernel<false>);
 REGISTER_DISPATCH(qrelu_leaky_stub, &leaky_qrelu_out_kernel);
 REGISTER_DISPATCH(qrelu_stub, &qrelu_kernel);
+REGISTER_DISPATCH(qprelu_stub, &qprelu_out_kernel);
 REGISTER_DISPATCH(qgelu_stub, &qgelu_kernel);
 REGISTER_DISPATCH(qsigmoid_stub, &qsigmoid_kernel);
 REGISTER_DISPATCH(qtanh_stub, &qtanh_kernel);
diff --git a/aten/src/ATen/native/quantized/cpu/qrelu.cpp b/aten/src/ATen/native/quantized/cpu/qrelu.cpp
index ad071d6..e4ca887 100644
--- a/aten/src/ATen/native/quantized/cpu/qrelu.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qrelu.cpp
@@ -17,6 +17,7 @@
 
 DEFINE_DISPATCH(qrelu_stub);
 DEFINE_DISPATCH(qrelu_leaky_stub);
+DEFINE_DISPATCH(qprelu_stub);
 
 #ifdef USE_PYTORCH_QNNPACK
 Tensor qnnpack_relu(Tensor input) {
@@ -134,6 +135,32 @@
   return self;
 }
 
+Tensor prelu_quantized_cpu_impl(const Tensor& self, const Tensor& weight,
+                                double output_scale, int64_t output_zero_point) {
+  auto ndim = self.dim();
+  // For 0-dim inputs or ndim > 5, fall back to the reference path (dequantize -> prelu -> quantize)
+  if (ndim > 5 || ndim < 1) {
+    auto x = self.dequantize();
+    auto y = at::prelu(x, weight);
+    return at::quantize_per_tensor(y, output_scale, output_zero_point, c10::kQUInt8);
+  }
+
+  auto qy = at::_empty_affine_quantized(self.sizes(),
+      at::device(kCPU)
+        .dtype(self.scalar_type()),
+      output_scale,
+      output_zero_point,
+      self.suggest_memory_format());
+
+  qprelu_stub(self.device().type(), qy, self, weight);
+
+  return qy;
+}
+
+Tensor prelu_quantized_cpu(const Tensor& self, const Tensor& weight) {
+  return prelu_quantized_cpu_impl(self, weight, self.q_scale(), self.q_zero_point());
+}
+
 namespace {
 Tensor quantized_relu6(const Tensor& qx) {
   Tensor qy;
@@ -175,9 +202,17 @@
   }
 };
 
+class QPRelu final {
+ public:
+  static Tensor run(Tensor self, const Tensor& weight, double output_scale, int64_t output_zero_point) {
+    return prelu_quantized_cpu_impl(self, weight, output_scale, output_zero_point);
+  }
+};
+
 TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
   m.impl(TORCH_SELECTIVE_NAME("quantized::relu6"), TORCH_FN(QRelu6::run));
   m.impl(TORCH_SELECTIVE_NAME("quantized::leaky_relu"), TORCH_FN(QLeakyRelu::run));
+  m.impl(TORCH_SELECTIVE_NAME("quantized::prelu"), TORCH_FN(QPRelu::run));
 }
 
 } // namespace
diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp
index ea5338d..a6ac4b3 100644
--- a/aten/src/ATen/native/quantized/library.cpp
+++ b/aten/src/ATen/native/quantized/library.cpp
@@ -187,6 +187,7 @@
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::max_pool2d(Tensor qx, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::relu6(Tensor qx, bool inplace=False) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::leaky_relu(Tensor qx, Scalar negative_slope, bool inplace, float output_scale, int output_zero_point) -> Tensor"));
+  m.def(TORCH_SELECTIVE_SCHEMA("quantized::prelu(Tensor qx, Tensor weight, float output_scale, int output_zero_point) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::sigmoid(Tensor qx, float output_scale, int output_zero_point) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::softmax(Tensor qx, int dim, float output_scale, int output_zero_point) -> Tensor"));
 }
diff --git a/test/quantization/core/test_quantized_module.py b/test/quantization/core/test_quantized_module.py
index 7cbab3b..067b7b4 100644
--- a/test/quantization/core/test_quantized_module.py
+++ b/test/quantization/core/test_quantized_module.py
@@ -1011,6 +1011,30 @@
             self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices,
                                              offsets, set_qconfig, is_emb_bag=True, dtype=qdtype)
 
+    def test_prelu(self):
+        x = torch.randn((4, 4, 4, 4), dtype=torch.float)
+        qx = torch.quantize_per_tensor(x, 1.0, 0, dtype=torch.quint8)
+
+        # num_parameters = 1
+        prelu_module = nnq.PReLU(output_scale=1.0, output_zero_point=0, num_parameters=1)
+        w = torch.randn(1, dtype=torch.float)
+        qw = torch.quantize_per_tensor(w, 1.0, 0, dtype=torch.quint8)
+        prelu_module.set_weight(qw)
+        qy = prelu_module(qx)
+        qy_ref = torch.prelu(qx, qw)
+
+        self.assertEqual(qy_ref, qy,
+                         msg="PReLU module API failed")
+
+        # num_parameters = num_channels
+        prelu_module = nnq.PReLU(output_scale=1.0, output_zero_point=0, num_parameters=4)
+        w = torch.randn(4, dtype=torch.float)
+        qw = torch.quantize_per_tensor(w, 1.0, 0, dtype=torch.quint8)
+        prelu_module.set_weight(qw)
+        qy = prelu_module(qx)
+        qy_ref = torch.prelu(qx, qw)
+        self.assertEqual(qy_ref, qy,
+                         msg="PReLU module API failed")
 
 class TestDynamicQuantizedModule(QuantizationTestCase):
     def _test_qconv_impl(self, q_mod, dq_mod, dim, dtype, bias):
diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py
index 51d7473..4d9842b 100644
--- a/test/quantization/core/test_quantized_op.py
+++ b/test/quantization/core/test_quantized_op.py
@@ -469,6 +469,40 @@
                 self.assertEqual(qY.dequantize(), qY_hat.dequantize(),
                                  msg="F.gelu failed ({} vs {})".format(qY, qY_hat))
 
+    """Tests the correctness of the quantized::prelu op."""
+    def test_qprelu(self):
+        shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4))
+        num_params = (0, 1)  # 0: num_parameter = num_channels
+        dtypes = (torch.quint8, torch.qint8)
+        memory_formats = (torch.channels_last, torch.contiguous_format)
+        test_cases = itertools.product(shapes, num_params, dtypes, memory_formats)
+        for shape, num_param, dtype, memory_format in test_cases:
+            if memory_format == torch.channels_last and len(shape) != 4:
+                continue
+            X, scale, zero_point, torch_type = \
+                torch.randn(*shape), 0.1, 0, dtype
+            X = X.to(memory_format=memory_format)
+            num_parameter = 1 if num_param == 1 or len(shape) == 1 else shape[1]
+            W = torch.randn(num_parameter)
+            w_scale = 0.2
+            w_zero_point = 0
+
+            qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
+                                           dtype=torch_type)
+            dqX = qX.dequantize()
+            qW = torch.quantize_per_tensor(W, scale=w_scale, zero_point=w_zero_point,
+                                           dtype=torch_type)
+            dqW = qW.dequantize()
+
+            op = torch.nn.functional.prelu
+            qop = torch.ops.quantized.prelu
+            dqY = op(dqX, dqW)
+            qY = torch.quantize_per_tensor(dqY, scale=scale, zero_point=zero_point,
+                                           dtype=torch_type)
+            qY_hat = qop(qX, qW, scale, zero_point)
+            self.assertEqual(qY.dequantize(), qY_hat.dequantize(),
+                             msg="F.prelu failed ({} vs {})".format(qY, qY_hat))
+
     """Tests the correctness of the quantized::qlayer_norm op."""
     @skipIfNoFBGEMM
     def test_qlayer_norm(self):
diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py
index 9498f34..f086a57 100644
--- a/test/quantization/fx/test_quantize_fx.py
+++ b/test/quantization/fx/test_quantize_fx.py
@@ -5782,6 +5782,28 @@
     def test_leaky_relu(self):
         self._test_activation_impl(nn.LeakyReLU, F.leaky_relu, nnq.LeakyReLU, torch.ops.quantized.leaky_relu)
 
+    def test_prelu(self):
+        class M(torch.nn.Module):
+            def __init__(self, num_param: int):
+                super(M, self).__init__()
+                self.op = torch.nn.PReLU(num_parameters=num_param)
+
+            def forward(self, input):
+                return self.op(input)
+
+        X = [[torch.randn(4, 4, 4, 4, dtype=torch.float)]]
+        options = itertools.product([1, 4], self.static_quant_types, [True, False])
+        quantized_nodes = {
+            # is_reference
+            True: ns.call_module(torch.nn.PReLU),
+            False: ns.call_module(torch.nn.quantized.PReLU),
+        }
+
+        for num_parameter, quant_type, is_reference in options:
+            self.checkGraphModeFxOp(
+                M(num_parameter), X, quant_type, quantized_nodes[is_reference],
+                is_reference=is_reference)
+
     def _test_norm_impl(
             self, float_module, float_op, op_args, data, quantized_module, quantized_op,
             skip_op_arg_for_functional=False):
diff --git a/test/test_module_init.py b/test/test_module_init.py
index d7d139b..61a9a2f 100644
--- a/test/test_module_init.py
+++ b/test/test_module_init.py
@@ -226,6 +226,7 @@
             'factory_kwargs': {},
         }),
         torch.nn.quantized.MaxPool2d: ((3,), {}),
+        torch.nn.quantized.PReLU: ((0.01, 0), {}),
         torch.nn.quantized.Quantize: ((0.1, 0), {
             'dtype': torch.int16,
             'factory_kwargs': {},
diff --git a/torch/ao/ns/fx/mappings.py b/torch/ao/ns/fx/mappings.py
index fc53a24..7bcd0ba 100644
--- a/torch/ao/ns/fx/mappings.py
+++ b/torch/ao/ns/fx/mappings.py
@@ -310,6 +310,16 @@
         set([
             nn.Softmax,
         ]),
+        # PReLU
+        set([
+            nn.PReLU,
+            nnq.PReLU,
+        ]),
+        # F.prelu
+        set([
+            F.prelu,
+            toq.prelu,
+        ]),
     ]
 
     # for each floating point op, add versions of the op added by
@@ -468,6 +478,7 @@
         operator.mul,
         torch.mul,
         torch.sum,
+        F.prelu,
     ])
 
     FUNS_IO_TYPE_FP16: Set[NSNodeTargetType] = set()
@@ -488,6 +499,7 @@
         toq.layer_norm,
         toq.leaky_relu,
         toq.dropout,
+        toq.prelu,
         # TODO(future PR): implement shadowing for binary ops and
         # uncomment below
         # toq.add,
@@ -568,6 +580,7 @@
         nn.SiLU,
         nn.Mish,
         nn.Softmax,
+        nn.PReLU,
         nni.BNReLU2d,
         nni.BNReLU3d,
         nni.ConvReLU1d,
@@ -613,6 +626,7 @@
         nnq.EmbeddingBag,
         nnq.Dropout,
         nnq.Softmax,
+        nnq.PReLU,
         nniq.BNReLU2d,
         nniq.BNReLU3d,
         nniq.ConvReLU1d,
diff --git a/torch/ao/quantization/backend_config/native.py b/torch/ao/quantization/backend_config/native.py
index 8b8e764..bf259b8 100644
--- a/torch/ao/quantization/backend_config/native.py
+++ b/torch/ao/quantization/backend_config/native.py
@@ -106,6 +106,7 @@
         torch.nn.InstanceNorm3d,
         torch.nn.LayerNorm,
         torch.nn.Dropout,
+        torch.nn.PReLU,
         torch.nn.functional.elu,
         torch.nn.functional.hardswish,
         torch.nn.functional.instance_norm,
diff --git a/torch/ao/quantization/fx/_lower_to_native_backend.py b/torch/ao/quantization/fx/_lower_to_native_backend.py
index 41fbb36..02fdc76 100644
--- a/torch/ao/quantization/fx/_lower_to_native_backend.py
+++ b/torch/ao/quantization/fx/_lower_to_native_backend.py
@@ -83,6 +83,7 @@
         torch.nn.InstanceNorm3d,
         torch.nn.LayerNorm,
         torch.nn.Dropout,
+        torch.nn.PReLU,
         torch.nn.BatchNorm2d,
         torch.nn.BatchNorm3d,
         torch.nn.intrinsic.BNReLU2d,
@@ -238,6 +239,7 @@
     nn.LayerNorm: nnq.LayerNorm,
     nn.Dropout: nnq.Dropout,
     nn.Softmax: nnq.Softmax,
+    nn.PReLU: nnq.PReLU,
     nni.BNReLU2d: nniq.BNReLU2d,
     nni.BNReLU3d: nniq.BNReLU3d,
 }
diff --git a/torch/ao/quantization/quantization_mappings.py b/torch/ao/quantization/quantization_mappings.py
index ebaa693..8b192b6 100644
--- a/torch/ao/quantization/quantization_mappings.py
+++ b/torch/ao/quantization/quantization_mappings.py
@@ -71,6 +71,7 @@
     nn.Linear: nnq.Linear,
     nn.ReLU6: nnq.ReLU6,
     nn.Dropout: nnq.Dropout,
+    nn.PReLU: nnq.PReLU,
     # Wrapper Modules:
     nnq.FloatFunctional: nnq.QFunctional,
     # Intrinsic modules:
diff --git a/torch/nn/quantized/modules/__init__.py b/torch/nn/quantized/modules/__init__.py
index 62cea10..2ccfe1d 100644
--- a/torch/nn/quantized/modules/__init__.py
+++ b/torch/nn/quantized/modules/__init__.py
@@ -1,7 +1,7 @@
 import torch
 from torch.nn.modules.pooling import MaxPool2d
 
-from .activation import ReLU6, Hardswish, ELU, LeakyReLU, Sigmoid, Softmax, MultiheadAttention
+from .activation import ReLU6, Hardswish, ELU, LeakyReLU, Sigmoid, Softmax, MultiheadAttention, PReLU
 from .dropout import Dropout
 from .batchnorm import BatchNorm2d, BatchNorm3d
 from .normalization import LayerNorm, GroupNorm, InstanceNorm1d, \
@@ -118,6 +118,7 @@
     'Sigmoid',
     'Softmax',
     'Dropout',
+    'PReLU',
     # Wrapper modules
     'FloatFunctional',
     'FXFloatFunctional',
diff --git a/torch/nn/quantized/modules/activation.py b/torch/nn/quantized/modules/activation.py
index beaae33..d1ce62b 100644
--- a/torch/nn/quantized/modules/activation.py
+++ b/torch/nn/quantized/modules/activation.py
@@ -225,3 +225,53 @@
             setattr(converted, 'bias_v', bias_v)  # noqa: B010
 
         return converted
+
+class PReLU(torch.nn.Module):
+    r"""This is the quantized equivalent of :class:`~torch.nn.PReLU`.
+
+    Args:
+        output_scale: quantization scale of the output tensor
+        output_zero_point: quantization zero point of the output tensor
+        num_parameters: number of parameters: 1, or the number of channels at the input. Default: 1
+    """
+    def __init__(self, output_scale: float, output_zero_point: int,
+                 num_parameters: int = 1) -> None:
+        super().__init__()
+        self.num_parameters = num_parameters
+        self.scale = output_scale
+        self.zero_point = output_zero_point
+        w = torch.randn(num_parameters, dtype=torch.float)
+        qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.quint8)
+        self.set_weight(qw)
+
+    def set_weight(self, w: torch.Tensor) -> None:
+        self.weight = w
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return torch.ops.quantized.prelu(input, self.weight, self.scale, self.zero_point)
+
+    def _get_name(self):
+        return 'QuantizedPReLU'
+
+    @classmethod
+    def from_float(cls, mod):
+        scale, zero_point = mod.activation_post_process.calculate_qparams()
+        qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
+        float_wt = mod.weight.float()
+        observer = mod.qconfig.weight()
+        wt_scale, wt_zp = observer.calculate_qparams()
+        qweight = torch.quantize_per_tensor(
+            float_wt, float(wt_scale), int(wt_zp), torch.quint8)
+        qprelu.set_weight(qweight)
+        return qprelu
+
+    @classmethod
+    def from_reference(cls, mod, scale, zero_point):
+        qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
+        float_wt = mod.weight.float()
+        observer = mod.qconfig.weight()
+        wt_scale, wt_zp = observer.calculate_qparams()
+        qweight = torch.quantize_per_tensor(
+            float_wt, float(wt_scale), int(wt_zp), torch.quint8)
+        qprelu.set_weight(qweight)
+        return qprelu