Back out "Revert D20229168: [quantization] Use torchbind for Linear PackedParams" (#38101)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38101

Original commit changeset: 29e8a4d3b8bf
ghstack-source-id: 103730417

Test Plan: waitforsadcastle

Differential Revision: D21471381

fbshipit-source-id: a922cdf31ba32021e7264ae1454c646c0bfd7ef4
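At a high level, this relands the move of quantized Linear packed weights off cpp_custom_type_hack tensors and onto the LinearPackedParamsBase torchbind class, so operators and RNN cell params dispatch through its virtual methods instead of branching on the qengine. A rough sketch of the resulting call surface, assuming `packed` came from one of the prepack() factories in this diff and `qinput` is a quantized input tensor:

    #include <ATen/native/quantized/cpu/packed_params.h>

    at::Tensor run_packed_linear(
        const c10::intrusive_ptr<LinearPackedParamsBase>& packed,
        const at::Tensor& qinput) {
      // Statically quantized path; apply_dynamic(qinput) would be the
      // dynamically quantized counterpart. Scale/zero point are example values.
      return packed->apply(qinput, /*output_scale=*/0.5, /*output_zero_point=*/0);
    }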
diff --git a/aten/src/ATen/native/QuantizedLinear.cpp b/aten/src/ATen/native/QuantizedLinear.cpp
index 56b4f25..cce2b0b 100644
--- a/aten/src/ATen/native/QuantizedLinear.cpp
+++ b/aten/src/ATen/native/QuantizedLinear.cpp
@@ -12,6 +12,8 @@
 #include <ATen/Parallel.h>
 #include <ATen/WrapDimUtilsMulti.h>
 #include <ATen/cpp_custom_type_hack.h>
+#include <ATen/native/quantized/cpu/fbgemm_utils.h>
+#include <ATen/native/quantized/cpu/packed_params.h>
 
 #ifdef USE_FBGEMM
 #include <fbgemm/Fbgemm.h>
@@ -19,11 +21,15 @@
 #include <fbgemm/QuantUtils.h>
 #endif // USE_FBGEMM
 
+namespace caffe2 {
+CAFFE_KNOWN_TYPE(c10::intrusive_ptr<LinearPackedParamsBase>);
+} // namespace caffe2
+
 #ifdef USE_FBGEMM
 namespace caffe2 {
 // Required for cpp_custom_type_hack to work
 CAFFE_KNOWN_TYPE(fbgemm::PackBMatrix<int8_t>);
-CAFFE_KNOWN_TYPE(fbgemm::PackedGemmMatrixFP16);
+CAFFE_KNOWN_TYPE(c10::intrusive_ptr<PackedLinearWeightFp16>);
 } // namespace caffe2
 #endif // USE_FBGEMM
 
@@ -360,7 +366,12 @@
   // flows across dll boundaries.
   auto ptr = std::make_unique<fbgemm::PackedGemmMatrixFP16>(
       fbgemm::matrix_op_t::Transpose, K, N, 1, weight_contig_ptr);
-  return cpp_custom_type_hack::create(std::move(ptr), weight.options());
+  c10::intrusive_ptr<LinearPackedParamsBase> packed_weight =
+      c10::make_intrusive<PackedLinearWeightFp16>(std::move(ptr), c10::nullopt);
+  auto unique_ptr_wrapper =
+      std::make_unique<decltype(packed_weight)>(std::move(packed_weight));
+  return cpp_custom_type_hack::create(
+      std::move(unique_ptr_wrapper), weight.options());
 }
 
 Tensor fbgemm_linear_fp16_weight_fp32_activation(
@@ -377,7 +388,10 @@
 
   // Pull out the PackedGemmMatrixFP16 instance from the owning tensor
   const fbgemm::PackedGemmMatrixFP16& packed_weight_fp16 =
-      cpp_custom_type_hack::cast<fbgemm::PackedGemmMatrixFP16>(packed_weight);
+      *c10::dynamic_intrusive_pointer_cast<PackedLinearWeightFp16>(
+           cpp_custom_type_hack::cast<
+               c10::intrusive_ptr<LinearPackedParamsBase>>(packed_weight))
+           ->w;
 
   TORCH_CHECK(input.size(input.dim() - 1) == packed_weight_fp16.numRows())
   TORCH_CHECK(input.dim() >= 2);
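With the hunk above, the opaque tensor returned by fbgemm_pack_gemm_matrix_fp16 now stores a c10::intrusive_ptr<LinearPackedParamsBase> (concretely a PackedLinearWeightFp16) rather than the raw fbgemm::PackedGemmMatrixFP16, and consumers recover the FBGEMM matrix through the `w` member. A hedged sketch of that unwrapping, mirroring the cast added above (the function name is illustrative):

    int64_t packed_fp16_rows(const at::Tensor& packed_weight) {
      auto params = cpp_custom_type_hack::cast<
          c10::intrusive_ptr<LinearPackedParamsBase>>(packed_weight);
      auto fp16 =
          c10::dynamic_intrusive_pointer_cast<PackedLinearWeightFp16>(params);
      TORCH_CHECK(fp16, "expected an FBGEMM fp16 packed weight");
      // The raw fbgemm::PackedGemmMatrixFP16 is still reachable via ->w.
      return fp16->w->numRows();
    }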
diff --git a/aten/src/ATen/native/RNN.cpp b/aten/src/ATen/native/RNN.cpp
index d909cee..6019929 100644
--- a/aten/src/ATen/native/RNN.cpp
+++ b/aten/src/ATen/native/RNN.cpp
@@ -5,10 +5,13 @@
 #include <ATen/core/op_registration/op_registration.h>
 #include <ATen/cpp_custom_type_hack.h>
 #include <ATen/native/c10_utils.h>
+#include <ATen/native/quantized/cpu/packed_params.h>
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
 #include <ATen/native/quantized/cpu/qnnpack_utils.h>
 #include <torch/custom_class.h>
 
+torch::jit::class_<LinearPackedParamsBase> register_linear_params();
+
 namespace at { namespace native {
 
 namespace {
@@ -58,7 +61,8 @@
     std::string,
     std::vector<at::Tensor>,
     std::vector<double>,
-    std::vector<int64_t>>;
+    std::vector<int64_t>,
+    std::vector<c10::intrusive_ptr<LinearPackedParamsBase>>>;
 
 // Base class so we can polymorphically handle these
 struct CellParamsBase : torch::CustomClassHolder {
@@ -193,14 +197,16 @@
         "quantized",
         std::move(tensors_to_serialize),
         std::move(doubles_to_serialize),
-        std::move(longs_to_serialize));
+        std::move(longs_to_serialize),
+        {});
   }
   static c10::intrusive_ptr<CellParamsBase> __setstate__(
       CellParamsSerializationType state) {
     std::vector<at::Tensor> tensors;
     std::vector<double> doubles;
     std::vector<int64_t> longs;
-    std::tie(std::ignore, tensors, doubles, longs) = std::move(state);
+    std::tie(std::ignore, tensors, doubles, longs, std::ignore) =
+        std::move(state);
     TORCH_INTERNAL_ASSERT(tensors.size() == 6);
     TORCH_INTERNAL_ASSERT(doubles.size() == 2);
     TORCH_INTERNAL_ASSERT(longs.size() == 2);
@@ -278,15 +284,17 @@
 // aten/src/ATen/native/quantized/cpu/fbgemm_utils.h.
 
 c10::intrusive_ptr<CellParamsBase> make_quantized_cell_params_dynamic(
-    at::Tensor w_ih_packed,
-    at::Tensor w_hh_packed,
+    c10::intrusive_ptr<LinearPackedParamsBase> w_ih_packed,
+    c10::intrusive_ptr<LinearPackedParamsBase> w_hh_packed,
     at::Tensor bias_ih,
     at::Tensor bias_hh);
 
 struct QuantizedCellParamsDynamic : public CellParamsBase {
   QuantizedCellParamsDynamic(
-      Tensor _packed_w_ih, /* Prepacked Weight Tensor */
-      Tensor _packed_w_hh, /* Prepacked Weight Tensor */
+      c10::intrusive_ptr<LinearPackedParamsBase>
+          _packed_w_ih, /* Prepacked Weight Tensor */
+      c10::intrusive_ptr<LinearPackedParamsBase>
+          _packed_w_hh, /* Prepacked Weight Tensor */
       Tensor _b_ih, /* float Bias Tensor */
       Tensor _b_hh /* float Bias Tensor */)
       : packed_w_ih(std::move(_packed_w_ih)),
@@ -294,8 +302,8 @@
         b_ih_(std::move(_b_ih)),
         b_hh_(std::move(_b_hh)) {}
 
-  const Tensor packed_w_ih;
-  const Tensor packed_w_hh;
+  c10::intrusive_ptr<LinearPackedParamsBase> packed_w_ih;
+  c10::intrusive_ptr<LinearPackedParamsBase> packed_w_hh;
   const Tensor b_ih_;
   const Tensor b_hh_;
 
@@ -307,26 +315,10 @@
   }
 
   Tensor linear_ih(const Tensor& input_ih) const override {
-    const auto kFuncName = "quantized::linear_dynamic";
-    const auto kOvrldName = "";
-    const std::vector<c10::IValue> output_ih_list =
-        callOp(kFuncName, kOvrldName, input_ih, packed_w_ih);
-    TORCH_INTERNAL_ASSERT(
-        output_ih_list.size() == 1,
-        "The output vector should have exact one element");
-    const Tensor output_ih = output_ih_list[0].toTensor();
-    return output_ih;
+    return packed_w_ih->apply_dynamic(input_ih);
   }
   Tensor linear_hh(const Tensor& input_hh) const override {
-    const auto kFuncName = "quantized::linear_dynamic";
-    const auto kOvrldName = "";
-    const std::vector<c10::IValue> output_hh_list =
-        callOp(kFuncName, kOvrldName, input_hh, packed_w_hh);
-    TORCH_INTERNAL_ASSERT(
-        output_hh_list.size() == 1,
-        "The output vector should have exact one element");
-    const Tensor output_hh = output_hh_list[0].toTensor();
-    return output_hh;
+    return packed_w_hh->apply_dynamic(input_hh);
   }
 
   const Tensor& b_ih() const override {
@@ -338,58 +330,44 @@
   CellParamsSerializationType __getstate__() const override {
     // Boxed dispatch nonsense
     // This will be cleaned up in the subsequent PR
-    auto unpacked_ih = callOp("quantized::linear_unpack", "", packed_w_ih);
-    TORCH_INTERNAL_ASSERT(unpacked_ih.size() == 2);
-    auto unpacked_hh = callOp("quantized::linear_unpack", "", packed_w_hh);
-    TORCH_INTERNAL_ASSERT(unpacked_hh.size() == 2);
+    auto unpacked_ih = packed_w_ih->unpack();
+    auto unpacked_hh = packed_w_hh->unpack();
 
     std::vector<at::Tensor> tensors_to_serialize{
-        /*w_ih=*/std::move(unpacked_ih[0]).toTensor(),
-        /*w_hh=*/std::move(unpacked_hh[0]).toTensor(),
         /*b_ih=*/b_ih_,
         /*b_hh=*/b_hh_,
     };
 
+    std::vector<c10::intrusive_ptr<LinearPackedParamsBase>>
+        packed_params_to_serialize{packed_w_ih, packed_w_hh};
+
     return CellParamsSerializationType(
-        "quantized_dynamic", std::move(tensors_to_serialize), {}, {});
+        "quantized_dynamic",
+        std::move(tensors_to_serialize),
+        {},
+        {},
+        std::move(packed_params_to_serialize));
   }
   static c10::intrusive_ptr<CellParamsBase> __setstate__(
       CellParamsSerializationType state) {
     std::vector<at::Tensor> tensors;
-    std::vector<double> doubles;
-    std::vector<int64_t> longs;
-    std::tie(std::ignore, tensors, doubles, longs) = std::move(state);
-    TORCH_INTERNAL_ASSERT(tensors.size() == 4);
-
-    at::Tensor b_ih = std::move(tensors[2]);
-    at::Tensor b_hh = std::move(tensors[3]);
-
-    // Boxed dispatch nonsense
-    // This will be cleaned up in the subsequent PR
-    auto packed_ih = callOp(
-        "quantized::linear_prepack",
-        "",
-        /*w_ih=*/std::move(tensors[0]),
-        /*b_ih=*/b_ih);
-    TORCH_INTERNAL_ASSERT(packed_ih.size() == 1);
-    auto packed_hh = callOp(
-        "quantized::linear_prepack",
-        "",
-        /*w_hh=*/std::move(tensors[1]),
-        /*b_hh=*/b_hh);
-    TORCH_INTERNAL_ASSERT(packed_hh.size() == 1);
+    std::vector<c10::intrusive_ptr<LinearPackedParamsBase>> packed_params;
+    std::tie(std::ignore, tensors, std::ignore, std::ignore, packed_params) =
+        std::move(state);
+    TORCH_INTERNAL_ASSERT(tensors.size() == 2);
+    TORCH_INTERNAL_ASSERT(packed_params.size() == 2);
 
     return make_quantized_cell_params_dynamic(
-        /*w_ih_packed=*/std::move(packed_ih[0]).toTensor(),
-        /*w_hh_packed=*/std::move(packed_hh[0]).toTensor(),
-        /*bias_ih=*/std::move(b_ih),
-        /*bias_hh=*/std::move(b_hh));
+        /*w_ih_packed=*/std::move(packed_params[0]),
+        /*w_hh_packed=*/std::move(packed_params[1]),
+        /*bias_ih=*/std::move(tensors[0]),
+        /*bias_hh=*/std::move(tensors[1]));
   }
 };
 
 c10::intrusive_ptr<CellParamsBase> make_quantized_cell_params_dynamic(
-    at::Tensor w_ih_packed,
-    at::Tensor w_hh_packed,
+    c10::intrusive_ptr<LinearPackedParamsBase> w_ih_packed,
+    c10::intrusive_ptr<LinearPackedParamsBase> w_hh_packed,
     at::Tensor bias_ih,
     at::Tensor bias_hh) {
   return c10::make_intrusive<QuantizedCellParamsDynamic>(
@@ -400,24 +378,17 @@
 }
 
 c10::intrusive_ptr<CellParamsBase> make_quantized_cell_params_fp16(
-    at::Tensor w_ih_packed,
-    at::Tensor w_hh_packed,
-    at::Tensor b_ih,
-    at::Tensor b_hh);
+    c10::intrusive_ptr<LinearPackedParamsBase> w_ih_packed,
+    c10::intrusive_ptr<LinearPackedParamsBase> w_hh_packed);
 
 struct QuantizedCellParamsFP16 : public CellParamsBase {
   QuantizedCellParamsFP16(
-      Tensor _packed_ih,
-      Tensor _packed_hh,
-      Tensor _b_ih,
-      Tensor _b_hh)
-      : packed_ih(std::move(_packed_ih)),
-        packed_hh(std::move(_packed_hh)),
-        b_ih_(std::move(_b_ih)),
-        b_hh_(std::move(_b_hh)) {}
+      c10::intrusive_ptr<LinearPackedParamsBase> _packed_ih,
+      c10::intrusive_ptr<LinearPackedParamsBase> _packed_hh)
+      : packed_ih(std::move(_packed_ih)), packed_hh(std::move(_packed_hh)) {}
 
-  const Tensor packed_ih;
-  const Tensor packed_hh;
+  c10::intrusive_ptr<LinearPackedParamsBase> packed_ih;
+  c10::intrusive_ptr<LinearPackedParamsBase> packed_hh;
   const Tensor b_ih_;
   const Tensor b_hh_;
 
@@ -427,37 +398,11 @@
   Tensor matmul_hh(const Tensor& /* unused */) const override {
     TORCH_CHECK(false, "matmul is not supported with quantized cell params");
   }
-  Tensor linear_common(
-      const Tensor& input,
-      const Tensor& packed_weight,
-      const Tensor& bias) const {
-#ifdef USE_FBGEMM
-    // Stupid hack because somehow we ended up with two separate
-    // FBGEMM packed fp16 weight formats in the system. Remove when
-    // we kill one of them.
-    if (cpp_custom_type_hack::isa<fbgemm::PackedGemmMatrixFP16>(
-            packed_weight)) {
-      return at::native::fbgemm_linear_fp16_weight_fp32_activation(
-          input, packed_weight, bias);
-    }
-#endif // USE_FBGEMM
-
-    const auto kFuncName = "quantized::linear_dynamic_fp16";
-    const auto kOvrldName = "";
-    const std::vector<c10::IValue> output_list =
-        callOp(kFuncName, kOvrldName, input, packed_weight);
-    TORCH_INTERNAL_ASSERT(
-        output_list.size() == 1,
-        "The output vector should have exact one element");
-    const Tensor output = output_list[0].toTensor();
-    return output;
-    TORCH_INTERNAL_ASSERT(false);
-  }
   Tensor linear_ih(const Tensor& input) const override {
-    return linear_common(input, packed_ih, b_ih_);
+    return packed_ih->apply_dynamic(input);
   }
   Tensor linear_hh(const Tensor& h) const override {
-    return linear_common(h, packed_hh, b_hh_);
+    return packed_hh->apply_dynamic(h);
   }
 
   const Tensor& b_ih() const override {
@@ -467,64 +412,31 @@
     return b_hh_;
   }
   CellParamsSerializationType __getstate__() const override {
-    // Boxed dispatch nonsense
-    // This will be cleaned up in the subsequent PR
-    auto unpacked_ih = callOp("quantized::linear_unpack_fp16", "", packed_ih);
-    TORCH_INTERNAL_ASSERT(unpacked_ih.size() == 2);
-    auto unpacked_hh = callOp("quantized::linear_unpack_fp16", "", packed_hh);
-    TORCH_INTERNAL_ASSERT(unpacked_hh.size() == 2);
-
-    std::vector<at::Tensor> tensors_to_serialize{
-        /*w_ih=*/std::move(unpacked_ih[0]).toTensor(),
-        /*w_hh=*/std::move(unpacked_hh[0]).toTensor(),
-        /*b_ih=*/b_ih_,
-        /*b_hh=*/b_hh_};
+    std::vector<c10::intrusive_ptr<LinearPackedParamsBase>>
+        packed_params_to_serialize{packed_ih, packed_hh};
 
     return CellParamsSerializationType(
-        "quantized_fp16", std::move(tensors_to_serialize), {}, {});
+        "quantized_fp16", {}, {}, {}, std::move(packed_params_to_serialize));
   }
   static c10::intrusive_ptr<CellParamsBase> __setstate__(
       CellParamsSerializationType state) {
-    std::string type;
-    std::vector<at::Tensor> tensors;
-    std::vector<double> doubles;
-    std::vector<int64_t> longs;
-    std::tie(type, tensors, doubles, longs) = std::move(state);
-    TORCH_INTERNAL_ASSERT(tensors.size() == 4);
-
-    // Boxed dispatch nonsense
-    // This will be cleaned up in the subsequent PR
-    auto packed_ih = callOp(
-        "quantized::linear_prepack_fp16",
-        "",
-        /*w_ih=*/std::move(tensors[0]),
-        /*b_ih=*/tensors[2]);
-    TORCH_INTERNAL_ASSERT(packed_ih.size() == 1);
-    auto packed_hh = callOp(
-        "quantized::linear_prepack_fp16",
-        "",
-        /*w_hh=*/std::move(tensors[1]),
-        /*b_hh=*/tensors[3]);
-    TORCH_INTERNAL_ASSERT(packed_hh.size() == 1);
+    std::vector<c10::intrusive_ptr<LinearPackedParamsBase>> packed_params;
+    std::tie(
+        std::ignore, std::ignore, std::ignore, std::ignore, packed_params) =
+        std::move(state);
+    TORCH_INTERNAL_ASSERT(packed_params.size() == 2);
 
     return make_quantized_cell_params_fp16(
-        /*w_ih_packed=*/std::move(packed_ih[0]).toTensor(),
-        /*w_hh_packed=*/std::move(packed_hh[0]).toTensor(),
-        /*b_ih=*/std::move(tensors[2]),
-        /*b_hh=*/std::move(tensors[3]));
+        /*w_ih_packed=*/std::move(packed_params[0]),
+        /*w_hh_packed=*/std::move(packed_params[1]));
   }
 };
 
 c10::intrusive_ptr<CellParamsBase> make_quantized_cell_params_fp16(
-    at::Tensor w_ih_packed,
-    at::Tensor w_hh_packed,
-    at::Tensor b_ih,
-    at::Tensor b_hh) {
+    c10::intrusive_ptr<LinearPackedParamsBase> w_ih_packed,
+    c10::intrusive_ptr<LinearPackedParamsBase> w_hh_packed) {
   return c10::make_intrusive<QuantizedCellParamsFP16>(
-      std::move(w_ih_packed),
-      std::move(w_hh_packed),
-      std::move(b_ih),
-      std::move(b_hh));
+      std::move(w_ih_packed), std::move(w_hh_packed));
 }
 
 static std::unordered_map<
@@ -630,65 +542,27 @@
   return c10::List<c10::intrusive_ptr<CellParamsBase>>(result);
 }
 
-static std::vector<c10::intrusive_ptr<CellParamsBase>> _quantized_params_dynamic(
-    c10::List<at::Tensor> params,
-    std::string qengine) {
+static c10::List<c10::intrusive_ptr<CellParamsBase>>
+gather_quantized_params_dynamic(c10::List<at::Tensor> params) {
   static at::Tensor undefined;
   std::vector<c10::intrusive_ptr<CellParamsBase>> result;
   for (size_t i = 0; i < params.size(); i += 2) {
-    at::Tensor bias_ih, bias_hh;
+    auto packed_struct_ih =
+        cpp_custom_type_hack::cast<c10::intrusive_ptr<LinearPackedParamsBase>>(
+            static_cast<at::Tensor>(params[i]));
+    auto packed_struct_hh =
+        cpp_custom_type_hack::cast<c10::intrusive_ptr<LinearPackedParamsBase>>(
+            static_cast<at::Tensor>(params[i + 1]));
 
-    if (qengine == "fbgemm") {
-#ifdef USE_FBGEMM
-      auto& packed_struct_ih = cpp_custom_type_hack::cast<PackedLinearWeight>(
-          static_cast<at::Tensor>(params[i]));
-      auto& packed_struct_hh = cpp_custom_type_hack::cast<PackedLinearWeight>(
-          static_cast<at::Tensor>(params[i + 1]));
-
-      bias_ih = packed_struct_ih.bias.value_or(undefined);
-      bias_hh = packed_struct_hh.bias.value_or(undefined);
-#endif
-      } else if (qengine == "qnnpack") {
-#ifdef USE_PYTORCH_QNNPACK
-        auto& packed_struct_ih =
-            cpp_custom_type_hack::cast<PackedLinearWeightsQnnp>(
-                static_cast<at::Tensor>(params[i]));
-        auto& packed_struct_hh =
-            cpp_custom_type_hack::cast<PackedLinearWeightsQnnp>(
-                static_cast<at::Tensor>(params[i + 1]));
-
-        bias_ih = packed_struct_ih.bias;
-        bias_hh = packed_struct_hh.bias;
-#endif
-      }
-      result.emplace_back(c10::make_intrusive<QuantizedCellParamsDynamic>(
-          static_cast<at::Tensor>(params[i]),
-          static_cast<at::Tensor>(params[i + 1]),
-          bias_ih,
-          bias_hh));
-    }
-    return result;
-}
-
-static c10::List<c10::intrusive_ptr<CellParamsBase>>
-gather_quantized_params_dynamic(c10::List<at::Tensor> params) {
-  TORCH_CHECK(
-      params.size() % 2 == 0,
-      "got an incorrect number of quantized RNN parameters");
-  auto& ctx = at::globalContext();
-#ifdef USE_FBGEMM
-  if (ctx.qEngine() == at::QEngine::FBGEMM){
-    return c10::List<c10::intrusive_ptr<CellParamsBase>>(
-        _quantized_params_dynamic(std::move(params), "fbgemm"));
-}
-#endif
-#ifdef USE_PYTORCH_QNNPACK
-  if (ctx.qEngine() == at::QEngine::QNNPACK) {
-    return c10::List<c10::intrusive_ptr<CellParamsBase>>(
-        _quantized_params_dynamic(std::move(params), "qnnpack"));
+    auto bias_ih = packed_struct_ih->bias().value_or(undefined);
+    auto bias_hh = packed_struct_hh->bias().value_or(undefined);
+    result.emplace_back(c10::make_intrusive<QuantizedCellParamsDynamic>(
+        std::move(packed_struct_ih),
+        std::move(packed_struct_hh),
+        std::move(bias_ih),
+        std::move(bias_hh)));
   }
-#endif
-  TORCH_INTERNAL_ASSERT(false, "Tried to use quantized RNN without FBGEMM or QNNPACK!")
+  return c10::List<c10::intrusive_ptr<CellParamsBase>>(result);
 }
 
 static c10::List<c10::intrusive_ptr<CellParamsBase>>
@@ -698,11 +572,28 @@
   TORCH_CHECK(params.size() % 4 == 0,
               "incorrect number of quantized RNN parameters FP16");
   for (size_t i = 0; i < params.size(); i += 4) {
+    c10::intrusive_ptr<LinearPackedParamsBase> packed_struct_ih =
+        cpp_custom_type_hack::cast<c10::intrusive_ptr<LinearPackedParamsBase>>(
+            static_cast<at::Tensor>(params[i]));
+    c10::intrusive_ptr<LinearPackedParamsBase> packed_struct_hh =
+        cpp_custom_type_hack::cast<c10::intrusive_ptr<LinearPackedParamsBase>>(
+            static_cast<at::Tensor>(params[i + 1]));
+
+    // NB: we install the bias from the gathered parameters here because
+    // in the "new world", the fp16 linear apply() method always expects
+    // the bias to be present in the packed struct. In the "old world",
+    // we called `fbgemm_linear_fp16_weight_fp32_activation`, which took
+    // the bias explicitly and ignored the bias in the packed struct. To
+    // reconcile serialized models that behaved in the old style, we
+    // put the bias into the appropriate packed structures here.
+    //
+    // Hopefully we can remove this in the future when we eliminate
+    // the old style altogether
+    packed_struct_ih->set_bias(params[i + 2]);
+    packed_struct_hh->set_bias(params[i + 3]);
+
     result.emplace_back(c10::make_intrusive<QuantizedCellParamsFP16>(
-        static_cast<at::Tensor>(params[i]),
-        static_cast<at::Tensor>(params[i + 1]),
-        static_cast<at::Tensor>(params[i + 2]),
-        static_cast<at::Tensor>(params[i + 3])));
+        std::move(packed_struct_ih), std::move(packed_struct_hh)));
   }
   return c10::List<c10::intrusive_ptr<CellParamsBase>>(result);
 }
@@ -1910,6 +1801,8 @@
 
 namespace {
 
+static auto ensure_linear_params_registered = register_linear_params();
+
 static auto cell_params_base_registry =
     torch::class_<CellParamsBase>("rnn", "CellParamsBase")
         .def_pickle(
@@ -1942,17 +1835,17 @@
                 .kernel<
                     decltype(quantized_lstm_data_legacy),
                     quantized_lstm_data_legacy>(DispatchKey::CPUTensorId))
-        .op("quantized::make_quantized_cell_params_dynamic(Tensor w_ih, Tensor w_hh, Tensor bias_ih, Tensor bias_hh) -> __torch__.torch.classes.rnn.CellParamsBase",
+        .op("quantized::make_quantized_cell_params_dynamic(__torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor bias_ih, Tensor bias_hh) -> __torch__.torch.classes.rnn.CellParamsBase",
             torch::RegisterOperators::options()
                 .kernel<
                     decltype(make_quantized_cell_params_dynamic),
                     make_quantized_cell_params_dynamic>(
                     DispatchKey::CPUTensorId))
-        .op("quantized::make_quantized_cell_params_fp16(Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh) -> __torch__.torch.classes.rnn.CellParamsBase",
+        .op("quantized::make_quantized_cell_params_fp16(__torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh) -> __torch__.torch.classes.rnn.CellParamsBase",
             torch::RegisterOperators::options()
-                .kernel<
+                .catchAllKernel<
                     decltype(make_quantized_cell_params_fp16),
-                    make_quantized_cell_params_fp16>(DispatchKey::CPUTensorId))
+                    &make_quantized_cell_params_fp16>())
         .op("quantized::make_quantized_cell_params(Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh) -> __torch__.torch.classes.rnn.CellParamsBase",
             torch::RegisterOperators::options()
                 .kernel<
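The RNN.cpp changes above replace the prepacked-weight tensors in the cell params with LinearPackedParamsBase handles and add a fifth slot to CellParamsSerializationType so packed params serialize directly, without boxed calls to quantized::linear_unpack/prepack. An illustrative round trip (these types live in RNN.cpp's anonymous namespace, so this is a sketch rather than an external API):

    at::Tensor roundtrip_and_apply(
        QuantizedCellParamsDynamic& cell, const at::Tensor& input) {
      // Tuple layout: <type, tensors, doubles, longs, packed linear params>
      CellParamsSerializationType state = cell.__getstate__();
      c10::intrusive_ptr<CellParamsBase> restored =
          QuantizedCellParamsDynamic::__setstate__(std::move(state));
      // linear_ih now forwards to LinearPackedParamsBase::apply_dynamic.
      return restored->linear_ih(input);
    }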
diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp
index d0e71e6..b2ec356 100644
--- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp
+++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp
@@ -3,8 +3,6 @@
 #include <ATen/native/quantized/cpu/conv_packed_params.h>
 #include <torch/custom_class.h>
 
-#ifdef USE_FBGEMM
-
 #include <ATen/ATen.h>
 #include <ATen/native/TensorFactories.h>
 #include <ATen/quantized/QTensorImpl.h>
@@ -13,6 +11,15 @@
 #include <c10/core/QScheme.h>
 #include <c10/core/TensorOptions.h>
 
+#include <torch/custom_class.h>
+
+#include <ATen/native/quantized/cpu/packed_params.h>
+#include <ATen/native/quantized/cpu/qnnpack_utils.h>
+
+torch::jit::class_<LinearPackedParamsBase> register_linear_params();
+
+#ifdef USE_FBGEMM
+
 namespace at {
 namespace native {
 namespace fbgemm_utils {
@@ -324,9 +331,64 @@
 template
 CAFFE2_API torch::jit::class_<ConvPackedParamsBase<3>> register_conv_params<3>();
 
+torch::jit::class_<LinearPackedParamsBase> register_linear_params() {
+  using SerializationType = std::tuple<at::Tensor, c10::optional<at::Tensor>>;
+  static auto register_linear_params =
+      torch::jit::class_<LinearPackedParamsBase>(
+          "quantized", "LinearPackedParamsBase")
+          .def_pickle(
+              [](const c10::intrusive_ptr<LinearPackedParamsBase>& params)
+                  -> SerializationType { // __getstate__
+                at::Tensor weight;
+                c10::optional<at::Tensor> bias;
+                std::tie(weight, bias) = params->unpack();
+                return std::make_tuple(std::move(weight), std::move(bias));
+              },
+              [](SerializationType state)
+                  -> c10::intrusive_ptr<
+                      LinearPackedParamsBase> { // __setstate__
+                at::Tensor weight;
+                c10::optional<at::Tensor> bias;
+                weight = std::move(std::get<0>(state));
+                bias = std::move(std::get<1>(state));
+
+#ifdef USE_FBGEMM
+                if (at::globalContext().qEngine() == at::QEngine::FBGEMM) {
+                  if (weight.scalar_type() == at::kQInt8) {
+                    return PackedLinearWeight::prepack(
+                        std::move(weight), std::move(bias));
+                  } else if (weight.scalar_type() == at::kFloat) {
+                    // NB: fp16 weight is serialized as float
+                    return PackedLinearWeightFp16::prepack(
+                        std::move(weight), std::move(bias));
+                  } else {
+                    TORCH_CHECK(
+                        false,
+                        "Unsupported data type",
+                        c10::toString(weight.scalar_type()),
+                        " in serialized LinearPackedParams object!");
+                  }
+                }
+#endif // USE_FBGEMM
+#ifdef USE_PYTORCH_QNNPACK
+                if (at::globalContext().qEngine() == at::QEngine::QNNPACK) {
+                  TORCH_CHECK(
+                      weight.scalar_type() == at::kQInt8,
+                      "QNNPACK only supports INT8 bit width currently. Got ",
+                      c10::toString(weight.scalar_type()));
+                  return PackedLinearWeightsQnnp::prepack(
+                      std::move(weight), std::move(bias));
+                }
+#endif // USE_PYTORCH_QNNPACK
+                TORCH_CHECK(false, "Unknown qengine");
+              });
+  return register_linear_params;
+}
+
 namespace {
 
 static auto conv2d_params = register_conv_params<2>();
 static auto conv3d_params = register_conv_params<3>();
+static auto linear_params = register_linear_params();
 
 } // namespace
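register_linear_params() above keeps its torch::jit::class_ registration in a function-local static, so repeated calls are harmless. The rest of this diff relies on that by forward-declaring the function and binding a static initializer in each translation unit that needs the class registered before its ops; the pattern looks like:

    torch::jit::class_<LinearPackedParamsBase> register_linear_params();

    namespace {
    static auto linear_params_registered = register_linear_params();
    } // namespace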
diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h
index 69b5b59..d101391 100644
--- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h
+++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h
@@ -7,6 +7,7 @@
 
 #include <ATen/Tensor.h>
 #include <ATen/native/quantized/cpu/conv_packed_params.h>
+#include <ATen/native/quantized/cpu/packed_params.h>
 #include <c10/core/QScheme.h>
 
 
@@ -19,18 +20,100 @@
 // of the A rows. The column offsets are needed for the asymmetric quantization
 // (affine quantization) of input matrix.
 // Note that in JIT mode we can think of a way to fuse col_offsets with bias.
-struct CAFFE2_API PackedLinearWeight {
+struct CAFFE2_API PackedLinearWeight : public LinearPackedParamsBase {
+  PackedLinearWeight(
+      std::unique_ptr<fbgemm::PackBMatrix<int8_t>> w,
+      c10::optional<at::Tensor> bias,
+      std::vector<int32_t> col_offsets,
+      std::vector<float> w_scale,
+      std::vector<int32_t> w_zp,
+      c10::QScheme q_scheme)
+      : w(std::move(w)),
+        bias_(std::move(bias)),
+        col_offsets(std::move(col_offsets)),
+        w_scale(std::move(w_scale)),
+        w_zp(std::move(w_zp)),
+        q_scheme(std::move(q_scheme)) {}
   std::unique_ptr<fbgemm::PackBMatrix<int8_t>> w;
-  c10::optional<at::Tensor> bias;
+  c10::optional<at::Tensor> bias_;
   std::vector<int32_t> col_offsets;
   std::vector<float> w_scale;
   std::vector<int32_t> w_zp;
   c10::QScheme q_scheme;
+
+  at::Tensor apply(
+      at::Tensor input,
+      double output_scale,
+      int64_t output_zero_point) override;
+  at::Tensor apply_relu(
+      at::Tensor input,
+      double output_scale,
+      int64_t output_zero_point) override;
+
+  at::Tensor apply_dynamic(at::Tensor input) override;
+  at::Tensor apply_dynamic_relu(at::Tensor input) override;
+
+  std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() override;
+
+  c10::optional<at::Tensor> bias() override {
+    return bias_;
+  }
+
+  static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
+      at::Tensor weight,
+      c10::optional<at::Tensor> bias);
+
+ private:
+  template <bool ReluFused>
+  at::Tensor apply_impl(
+      at::Tensor input,
+      double output_scale,
+      int64_t output_zero_point);
+
+  template <bool ReluFused>
+  at::Tensor apply_dynamic_impl(at::Tensor input);
 };
 
-struct CAFFE2_API PackedLinearWeightFp16 {
+struct CAFFE2_API PackedLinearWeightFp16 : public LinearPackedParamsBase {
+  PackedLinearWeightFp16(
+      std::unique_ptr<fbgemm::PackedGemmMatrixFP16> w,
+      c10::optional<at::Tensor> bias)
+      : w(std::move(w)), bias_(std::move(bias)) {}
+
   std::unique_ptr<fbgemm::PackedGemmMatrixFP16> w;
-  c10::optional<at::Tensor> bias;
+  c10::optional<at::Tensor> bias_;
+
+  at::Tensor apply(
+      at::Tensor input,
+      double output_scale,
+      int64_t output_zero_point) override {
+    TORCH_INTERNAL_ASSERT(false);
+  }
+  at::Tensor apply_relu(
+      at::Tensor input,
+      double output_scale,
+      int64_t output_zero_point) override {
+    TORCH_INTERNAL_ASSERT(false);
+  }
+
+  at::Tensor apply_dynamic(at::Tensor input) override;
+  at::Tensor apply_dynamic_relu(at::Tensor input) override;
+
+  std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() override;
+
+  c10::optional<at::Tensor> bias() override {
+    return bias_;
+  }
+
+  static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
+      at::Tensor weight,
+      c10::optional<at::Tensor> bias);
+
+  void set_bias(c10::optional<at::Tensor> bias) override;
+
+ private:
+  template <bool ReluFused>
+  at::Tensor apply_dynamic_impl(at::Tensor input);
 };
 
 template <int kSpatialDim = 2>
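With the header changes above, both FBGEMM packed-weight structs implement the same LinearPackedParamsBase surface, so callers prepack once and then hold only the base handle. A hedged usage sketch (FBGEMM build assumed; `fp32_weight` and `input` are placeholder tensors):

    at::Tensor fp16_dynamic_linear(at::Tensor fp32_weight, at::Tensor input) {
      c10::intrusive_ptr<LinearPackedParamsBase> packed =
          PackedLinearWeightFp16::prepack(
              std::move(fp32_weight), /*bias=*/c10::nullopt);
      // fp16 weight, fp32 activation; the static-quantized apply()/apply_relu()
      // overrides intentionally assert for this struct. unpack() would give
      // back (Tensor weight, c10::optional<Tensor> bias).
      return packed->apply_dynamic(input);
    }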
diff --git a/aten/src/ATen/native/quantized/cpu/packed_params.h b/aten/src/ATen/native/quantized/cpu/packed_params.h
new file mode 100644
index 0000000..f13b6c1
--- /dev/null
+++ b/aten/src/ATen/native/quantized/cpu/packed_params.h
@@ -0,0 +1,27 @@
+#pragma once
+
+#include <ATen/core/ivalue.h>
+
+struct LinearPackedParamsBase : public torch::jit::CustomClassHolder {
+  virtual at::Tensor apply(
+      at::Tensor input,
+      double output_scale,
+      int64_t output_zero_point) = 0;
+  virtual at::Tensor apply_relu(
+      at::Tensor input,
+      double output_scale,
+      int64_t output_zero_point) = 0;
+
+  virtual at::Tensor apply_dynamic(at::Tensor input) = 0;
+  virtual at::Tensor apply_dynamic_relu(at::Tensor input) = 0;
+
+  virtual std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() = 0;
+
+  virtual c10::optional<at::Tensor> bias() = 0;
+
+  virtual void set_bias(c10::optional<at::Tensor> bias) {
+    throw std::runtime_error(
+        "set_bias is not implemented for this packed "
+        "parameter type");
+  }
+};
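The interface above leaves set_bias() with a throwing default on purpose: only packed-param types that can accept a bias after packing override it (in this diff, the FBGEMM fp16 weight), and the fp16 RNN gather path depends on that override. For a new backend, the minimum to implement looks like this sketch (the type name is hypothetical):

    struct MyPackedLinearWeight : public LinearPackedParamsBase {
      at::Tensor apply(
          at::Tensor input, double output_scale, int64_t output_zero_point) override;
      at::Tensor apply_relu(
          at::Tensor input, double output_scale, int64_t output_zero_point) override;
      at::Tensor apply_dynamic(at::Tensor input) override;
      at::Tensor apply_dynamic_relu(at::Tensor input) override;
      std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() override;
      c10::optional<at::Tensor> bias() override;
      // set_bias() is optional; the base-class default throws.
    };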
diff --git a/aten/src/ATen/native/quantized/cpu/qlinear.cpp b/aten/src/ATen/native/quantized/cpu/qlinear.cpp
index 85e486a..3c2614d 100644
--- a/aten/src/ATen/native/quantized/cpu/qlinear.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qlinear.cpp
@@ -1,14 +1,358 @@
 #include <ATen/ATen.h>
 #include <ATen/Parallel.h>
-#include <torch/library.h>
-#include <ATen/cpp_custom_type_hack.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
+#include <ATen/native/quantized/cpu/packed_params.h>
 #include <ATen/native/quantized/cpu/qnnpack_utils.h>
 #include <caffe2/utils/threadpool/ThreadPoolMobile.h>
+#include <torch/custom_class.h>
+#include <torch/library.h>
 
 #include <algorithm>
 #include <string>
 
+torch::jit::class_<LinearPackedParamsBase> register_linear_params();
+
+#ifdef USE_FBGEMM
+template <bool ReluFused>
+at::Tensor PackedLinearWeight::apply_impl(
+    at::Tensor input,
+    double output_scale,
+    int64_t output_zero_point) {
+  // uint8 * int8 -> uint8 (no quantization/dequantization)
+
+  // We make a strong guarantee that models using these operators will have
+  // the same numerics across different machines. Therefore, we do not provide
+  // a fallback path and rather fail loudly if we cannot run FBGEMM.
+  TORCH_CHECK(
+      fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
+
+  // TODO: contiguous is called for further jit optimizations.
+  auto input_contig = input.contiguous();
+  const auto* input_ptr =
+      reinterpret_cast<uint8_t*>(input_contig.data_ptr<c10::quint8>());
+
+  TORCH_CHECK(
+      input.dim() >= 2,
+      "The dimension of input tensor should be larger than or equal to 2");
+  // C(output) = A(input) x B(weight), where C, A, B are M x N, M x K, K x N
+  // matrices, respectively.
+  int64_t M = size_to_dim_(input.dim() - 1, input.sizes());
+
+  auto packB = w.get();
+
+  int64_t N = static_cast<int64_t>(packB->numCols());
+  int64_t K = input.size(input.dim() - 1);
+  TORCH_CHECK(
+      K == static_cast<int64_t>(packB->numRows()),
+      "The number of rows in the packB should be equal to K: " +
+          std::to_string(K));
+
+  float input_scale_float = input.q_scale();
+  int32_t input_zero_point_int32 = input.q_zero_point();
+
+  std::vector<float> output_multiplier_float(1, 0.0);
+  std::vector<float> act_times_w_scale(1, 0.0);
+  TORCH_CHECK(
+      w_scale.size() == w_zp.size(),
+      "Weight scales and zero points vectors should have the same size.");
+  if (q_scheme == c10::kPerTensorAffine) {
+    // Process the per tensor quantization.
+    act_times_w_scale[0] = (input_scale_float * w_scale[0]);
+    output_multiplier_float[0] =
+        act_times_w_scale[0] / static_cast<float>(output_scale);
+  } else if (q_scheme == c10::kPerChannelAffine) {
+    // Process the per channel quantization.
+    output_multiplier_float.resize(N, 0.0);
+    act_times_w_scale.resize(N, 1.0f);
+    for (int i = 0; i < N; ++i) {
+      act_times_w_scale[i] = (input_scale_float * w_scale[i]);
+      output_multiplier_float[i] =
+          act_times_w_scale[i] / static_cast<float>(output_scale);
+    }
+  }
+  int32_t output_zero_point_int32 = static_cast<int32_t>(output_zero_point);
+
+  const float* bias_ptr = nullptr;
+  at::Tensor bias;
+  if (this->bias_.has_value()) {
+    bias = this->bias_.value();
+    bias = bias.contiguous();
+    TORCH_CHECK(bias.dim() == 1, "bias should be a vector (1D Tensor)");
+    TORCH_CHECK(
+        bias.size(0) == N, "bias should have N elements: " + std::to_string(N));
+    bias_ptr = reinterpret_cast<float*>(bias.data_ptr<float>());
+  }
+
+  // The resulting matrix here is 2-D, let's view it with the original
+  // left hand dimensions of the input. Here are two examples:
+  // 1. If the input tensor is {M, K}, the output tensor is {M, N}.
+  // 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.
+  std::vector<int64_t> out_sizes = input.sizes().vec();
+  out_sizes.back() = N;
+  // Allocate output Tensor and a buffer for fbgemmPacked to use
+  auto output = at::_empty_affine_quantized(
+      out_sizes,
+      at::device(c10::kCPU).dtype(c10::kQUInt8),
+      output_scale,
+      output_zero_point);
+
+  auto buffer = at::empty(out_sizes, output.options().dtype(at::kInt));
+
+  int num_tasks = at::get_num_threads();
+  at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
+    for (int task_id = begin; task_id < end; ++task_id) {
+      // This operation does the following:
+      // 1) Creates a "row buffer" vector with offset values that must be
+      //    added to the integer matrix multiplication operation to ensure
+      //    correctness. This "row buffer" is also called the row offset, and
+      //    it is needed when we use affine quantization for weights.
+      // 2) Packs the resulting quantized matrix into vector-register and
+      //    cache friendly tiles.
+      //
+      //  Note this is not executed eagerly, but rather within the
+      //  fbgemmPacked call below.
+      fbgemm::PackAWithRowOffset<uint8_t> packA(
+          /*trans=*/fbgemm::matrix_op_t::NoTranspose,
+          /*nRow=*/M,
+          /*nCol=*/K,
+          /*smat=*/input_ptr,
+          /*ld=*/K,
+          /*pmat=*/nullptr); // Currently, packA manages ownership of `pmat`.
+                             // TODO: Consider a way to pre-allocate and reuse
+                             // pmat buffer.
+
+      // ReQuantizeOutput requires pointers to the zero point values,
+      // since in the case of rowwise quantization these will be arrays rather
+      // than scalars. But in this case, we're doing whole-tensor quantization
+      // so we just pass a pointer to the scale values (and internally
+      // ReQuantizeOutput won't index past 0).
+
+      // This is the end of the pipeline, pass the resulting matrix through.
+      fbgemm::DoNothing<> doNothingObj{};
+
+      if (q_scheme == c10::kPerTensorAffine) {
+        // Process the per tensor quantization.
+        //
+        // After the uint8 * int8 matrix multiplication is performed, this
+        // operation does:
+        //  1) Add in row and column offsets to the rows and columns,
+        //  respectively.
+        //  2) Add in the bias term.
+        fbgemm::ReQuantizeOutput<
+            ReluFused,
+            fbgemm::QuantizationGranularity::TENSOR,
+            float>
+            outputProcObj(
+                doNothingObj,
+                output_multiplier_float.data(),
+                output_zero_point_int32,
+                input_zero_point_int32,
+                w_zp.data(),
+                packA.getRowOffsetBuffer(),
+                col_offsets.data(),
+                bias_ptr,
+                N, /* nCol */
+                1 /* groups */,
+                act_times_w_scale.data());
+
+        // Do the GEMM
+        fbgemm::fbgemmPacked(
+            /*packA=*/packA,
+            /*packB=*/*packB,
+            /*C=*/reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()),
+            /*C_buffer=*/buffer.data_ptr<int32_t>(),
+            /*ldc=*/N,
+            /*outProcess=*/outputProcObj,
+            /*thread_id=*/task_id,
+            /*num_threads=*/num_tasks);
+      } else if (q_scheme == c10::kPerChannelAffine) {
+        // Process the per channel quantization.
+        //
+        // After the uint8 * int8 matrix multiplication is performed, this
+        // operation does:
+        //  1) Add in row and column offsets to the rows and columns,
+        //  respectively.
+        //  2) Add in the bias term.
+        fbgemm::ReQuantizeOutput<
+            ReluFused,
+            fbgemm::QuantizationGranularity::OUT_CHANNEL,
+            float>
+            outputProcObj(
+                doNothingObj,
+                output_multiplier_float.data(),
+                output_zero_point_int32,
+                input_zero_point_int32,
+                w_zp.data(),
+                packA.getRowOffsetBuffer(),
+                col_offsets.data(),
+                bias_ptr,
+                N, /*nCol=*/
+                1, /* groups*/
+                act_times_w_scale.data());
+
+        // Do the GEMM
+        fbgemm::fbgemmPacked(
+            /*packA=*/packA,
+            /*packB=*/*packB,
+            /*C=*/reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()),
+            /*C_buffer=*/buffer.data_ptr<int32_t>(),
+            /*ldc=*/N,
+            /*outProcess=*/outputProcObj,
+            /*thread_id=*/task_id,
+            /*num_threads=*/num_tasks);
+      }
+    }
+  });
+
+  return output;
+}
+
+at::Tensor PackedLinearWeight::apply(
+    at::Tensor input,
+    double output_scale,
+    int64_t output_zero_point) {
+  return apply_impl<false>(std::move(input), output_scale, output_zero_point);
+}
+
+at::Tensor PackedLinearWeight::apply_relu(
+    at::Tensor input,
+    double output_scale,
+    int64_t output_zero_point) {
+  return apply_impl<true>(std::move(input), output_scale, output_zero_point);
+}
+
+#endif // USE_FBGEMM
+
+#ifdef USE_PYTORCH_QNNPACK
+template <bool ReluFused>
+at::Tensor PackedLinearWeightsQnnp::apply_impl(
+    at::Tensor input,
+    double output_scale,
+    int64_t output_zero_point) {
+  TORCH_CHECK(
+      input.dim() >= 2,
+      "quantized::linear(): Input tensor rank should be >= 2");
+  auto input_contig = input.contiguous();
+
+  auto packB = w.get();
+  // Adjust weight zero point, similar to weight data.
+  auto kernel_zp = w_zp + 128;
+  auto kernel_scale = w_scale;
+  size_t rows_w = bias_.size(0);
+  size_t cols_w = input_contig.size(input_contig.dim() - 1);
+  auto input_scale = input_contig.q_scale();
+
+  if (!this->input_scale.has_value() ||
+      this->input_scale.value() != input_scale) {
+    // Get the original weight and adjust it to uint8 from int8
+    auto weight_contig = orig_weight;
+    auto bias_fp32 = bias_;
+    int8_t* w_data = (int8_t*)weight_contig.data_ptr<c10::qint8>();
+    at::Tensor qnnp_weight = at::_empty_affine_quantized(
+        weight_contig.sizes(),
+        at::device(c10::kCPU).dtype(c10::kQUInt8),
+        kernel_scale,
+        kernel_zp);
+    auto* qnnp_w_data = qnnp_weight.data_ptr<c10::quint8>();
+    auto wt_numel = weight_contig.numel();
+    for (int i = 0; i < wt_numel; ++i) {
+      qnnp_w_data[i] = static_cast<c10::quint8>(w_data[i] + 128);
+    }
+    // Original bias was float, so we requantize it here.
+    auto qbias = at::quantize_per_tensor(
+        bias_fp32, kernel_scale * input_scale, 0, c10::kQInt32);
+    // Update the input scale to not pack again.
+    this->input_scale = input_scale;
+    w.reset();
+    w = std::make_unique<qnnpack::PackBMatrix>(
+        cols_w /* input_channels */,
+        rows_w /* output_channels */,
+        kernel_zp,
+        kernel_scale,
+        (uint8_t*)qnnp_w_data,
+        (int32_t*)qbias.data_ptr<c10::qint32>());
+    packB = w.get();
+    if (at::globalContext().releaseWeightsWhenPrepacking()) {
+      // On mobile, we release the original weight by resetting the intrusive_ptr.
+      // Calling unpack after this will throw an assertion.
+      orig_weight.reset();
+      bias_.reset();
+    }
+  }
+
+  size_t rows_input = 1;
+  size_t cols_input = input_contig.size(input_contig.dim() - 1);
+  for (size_t i = 0; i < input_contig.dim() - 1; ++i) {
+    rows_input *= input_contig.size(i);
+  }
+
+  TORCH_CHECK(
+      cols_input == cols_w,
+      "quantized::linear(): input size does not match weight dimension 1 size: \
+         got ",
+      cols_input,
+      " but expected ",
+      cols_w);
+
+  // Allocate output Tensor and a buffer for QNNPACK to use
+  at::Tensor output = at::_empty_affine_quantized(
+      {static_cast<long>(rows_input), static_cast<long>(rows_w)},
+      input.options(),
+      output_scale,
+      output_zero_point);
+
+  auto output_min = ReluFused
+      ? activationLimits(output_scale, output_zero_point, Activation::RELU)
+            .first
+      : std::numeric_limits<uint8_t>::min();
+  auto output_max = ReluFused
+      ? activationLimits(output_scale, output_zero_point, Activation::RELU)
+            .second
+      : std::numeric_limits<uint8_t>::max();
+  TORCH_INTERNAL_ASSERT(packB != nullptr, "Packed Weights are NULL");
+  const pytorch_qnnp_status runStatus = qnnpack::qnnpackLinear(
+      rows_input /* batch_size */,
+      cols_input /* input_channels */,
+      rows_w /* output_channels */,
+      input_contig.q_zero_point(),
+      input_contig.q_scale(),
+      kernel_zp,
+      kernel_scale,
+      output_zero_point,
+      output_scale,
+      output_min,
+      output_max,
+      (uint8_t*)input_contig.data_ptr<c10::quint8>(),
+      cols_input /* input_stride */,
+      packB->getPackedWeights(),
+      (uint8_t*)output.data_ptr<c10::quint8>(),
+      rows_w /* output_stride */,
+      caffe2::mobile_pthreadpool() /* threadpool */);
+
+  TORCH_INTERNAL_ASSERT(
+      runStatus == pytorch_qnnp_status_success,
+      "failed to run QNNPACK Linear operator");
+
+  return output;
+}
+
+at::Tensor PackedLinearWeightsQnnp::apply(
+    at::Tensor input,
+    double output_scale,
+    int64_t output_zero_point) {
+  return apply_impl<false>(std::move(input), output_scale, output_zero_point);
+}
+
+at::Tensor PackedLinearWeightsQnnp::apply_relu(
+    at::Tensor input,
+    double output_scale,
+    int64_t output_zero_point) {
+  return apply_impl<true>(std::move(input), output_scale, output_zero_point);
+}
+
+#endif // USE_PYTORCH_QNNPACK
+
 namespace at {
 namespace native {
 namespace {
@@ -16,345 +360,18 @@
 template <bool ReluFused>
 class QLinearInt8 final {
  public:
-#ifdef USE_FBGEMM
-  static at::Tensor fbgemm_linear(
-      at::Tensor input,
-      at::Tensor packed_weight,
-      double output_scale,
-      int64_t output_zero_point) {
-    // uint8 * int8 -> uint8 (no quantization/dequantization)
-
-    // We make a strong guarantee that models using these operators will have
-    // the same numerics across different machines. Therefore, we do not provide
-    // a fallback path and rather fail loudly if we cannot run FBGEMM.
-    TORCH_CHECK(
-        fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
-
-    // TODO: contiguous is called for further jit optimizations.
-    auto input_contig = input.contiguous();
-    const auto* input_ptr =
-        reinterpret_cast<uint8_t*>(input_contig.data_ptr<c10::quint8>());
-
-    TORCH_CHECK(
-        input.dim() >= 2,
-        "The dimension of input tensor should be larger than or equal to 2");
-    // C(output) = A(input) x B(weight), where C, A, B are M x N, M x K, K x N
-    // matrices, respectively.
-    int64_t M = size_to_dim_(input.dim() - 1, input.sizes());
-
-    // Pull out the PackBMatrix and col_offsets instance from the owning tensor.
-    auto& pack_ptr =
-        cpp_custom_type_hack::cast<PackedLinearWeight>(packed_weight);
-    auto packB = pack_ptr.w.get();
-    // packB->printPackedMatrix("packedB inside fbgemm_linear (QLinearInt8): ");
-    auto& col_offsets = pack_ptr.col_offsets;
-
-    int64_t N = static_cast<int64_t>(packB->numCols());
-    int64_t K = input.size(input.dim() - 1);
-    TORCH_CHECK(
-        K == static_cast<int64_t>(packB->numRows()),
-        "The number of rows in the packB should be equal to K: " +
-            std::to_string(K));
-
-    float input_scale_float = input.q_scale();
-    int32_t input_zero_point_int32 = input.q_zero_point();
-
-    std::vector<float> output_multiplier_float(1, 0.0);
-    std::vector<float> act_times_w_scale(1, 0.0);
-    TORCH_CHECK(
-        pack_ptr.w_scale.size() == pack_ptr.w_zp.size(),
-        "Weight scales and zero points vectors should have the same size.");
-    if (pack_ptr.q_scheme == kPerTensorAffine) {
-      // Process the per tensor quantization.
-      act_times_w_scale[0] = (input_scale_float * pack_ptr.w_scale[0]);
-      output_multiplier_float[0] =
-          act_times_w_scale[0] / static_cast<float>(output_scale);
-    } else if (pack_ptr.q_scheme == kPerChannelAffine) {
-      // Process the per channel quantization.
-      output_multiplier_float.resize(N, 0.0);
-      act_times_w_scale.resize(N, 1.0f);
-      for (int i = 0; i < N; ++i) {
-        act_times_w_scale[i] = (input_scale_float * pack_ptr.w_scale[i]);
-        output_multiplier_float[i] =
-            act_times_w_scale[i] / static_cast<float>(output_scale);
-      }
-    }
-    int32_t output_zero_point_int32 = static_cast<int32_t>(output_zero_point);
-
-    const float* bias_ptr = nullptr;
-    at::Tensor bias;
-    if (pack_ptr.bias.has_value()) {
-      bias = pack_ptr.bias.value();
-      bias = bias.contiguous();
-      TORCH_CHECK(bias.dim() == 1, "bias should be a vector (1D Tensor)");
-      TORCH_CHECK(
-          bias.size(0) == N,
-          "bias should have N elements: " + std::to_string(N));
-      bias_ptr = reinterpret_cast<float*>(bias.data_ptr<float>());
-    }
-
-    // The resulting matrix here is 2-D, let's view it with the original
-    // left hand dimensions of the input. Here are two examples:
-    // 1. If the input tensor is {M, K}, the output tensor is {M, N}.
-    // 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.
-    std::vector<int64_t> out_sizes = input.sizes().vec();
-    out_sizes.back() = N;
-    // Allocate output Tensor and a buffer for fbgemmPacked to use
-    auto output = _empty_affine_quantized(
-        out_sizes,
-        at::device(kCPU).dtype(kQUInt8),
-        output_scale,
-        output_zero_point);
-
-    auto buffer = at::empty(out_sizes, output.options().dtype(at::kInt));
-
-    int num_tasks = at::get_num_threads();
-    at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
-      for (int task_id = begin; task_id < end; ++task_id) {
-        // This operation does the following:
-        // 1) Creates a "row buffer" vector with offset values that must be
-        //    added to the integer matrix multiplication operation to ensure
-        //    correctness. This "row buffer" is also called the row offset, and
-        //    it is needed when we use affine quantization for weights.
-        // 2) Packs the resulting quantized matrix into vector-register and
-        //    cache friendly tiles.
-        //
-        //  Note this is not executed eagerly, but rather within the
-        //  fbgemmPacked call below.
-        fbgemm::PackAWithRowOffset<uint8_t> packA(
-            /*trans=*/fbgemm::matrix_op_t::NoTranspose,
-            /*nRow=*/M,
-            /*nCol=*/K,
-            /*smat=*/input_ptr,
-            /*ld=*/K,
-            /*pmat=*/nullptr); // Currently, packA manages ownership of `pmat`.
-                               // TODO: Consider a way to pre-allocate and reuse
-                               // pmat buffer.
-
-        // ReQuantizeOutput requires pointers to the zero point values,
-        // since in the case of rowwise quantization these will be arrays rather
-        // than scalars. But in this case, we're doing whole-tensor quantization
-        // so we just pass a pointer to the scale values (and internally
-        // ReQuantizeOutput won't index past 0.
-
-        // This is the end of the pipeline, pass the resulting matrix through.
-        fbgemm::DoNothing<> doNothingObj{};
-
-        if (pack_ptr.q_scheme == kPerTensorAffine) {
-          // Process the per tensor quantization.
-          //
-          // After the uint8 * int8 matrix multiplication is performed, this
-          // operation does:
-          //  1) Add in row and column offsets to the rows and columns,
-          //  respectively.
-          //  2) Add in the bias term.
-          fbgemm::ReQuantizeOutput<
-              ReluFused,
-              fbgemm::QuantizationGranularity::TENSOR,
-              float>
-              outputProcObj(
-                  doNothingObj,
-                  output_multiplier_float.data(),
-                  output_zero_point_int32,
-                  input_zero_point_int32,
-                  pack_ptr.w_zp.data(),
-                  packA.getRowOffsetBuffer(),
-                  col_offsets.data(),
-                  bias_ptr,
-                  N, /* nCol */
-                  1 /* groups */,
-                  act_times_w_scale.data());
-
-          // Do the GEMM
-          fbgemm::fbgemmPacked(
-              /*packA=*/packA,
-              /*packB=*/*packB,
-              /*C=*/reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()),
-              /*C_buffer=*/buffer.data_ptr<int32_t>(),
-              /*ldc=*/N,
-              /*outProcess=*/outputProcObj,
-              /*thread_id=*/task_id,
-              /*num_threads=*/num_tasks);
-        } else if (pack_ptr.q_scheme == kPerChannelAffine) {
-          // Process the per channel quantization.
-          //
-          // After the uint8 * int8 matrix multiplication is performed, this
-          // operation does:
-          //  1) Add in row and column offsets to the rows and columns,
-          //  respectively.
-          //  2) Add in the bias term.
-          fbgemm::ReQuantizeOutput<
-              ReluFused,
-              fbgemm::QuantizationGranularity::OUT_CHANNEL,
-              float>
-              outputProcObj(
-                  doNothingObj,
-                  output_multiplier_float.data(),
-                  output_zero_point_int32,
-                  input_zero_point_int32,
-                  pack_ptr.w_zp.data(),
-                  packA.getRowOffsetBuffer(),
-                  col_offsets.data(),
-                  bias_ptr,
-                  N, /*nCol=*/
-                  1, /* groups*/
-                  act_times_w_scale.data());
-
-          // Do the GEMM
-          fbgemm::fbgemmPacked(
-              /*packA=*/packA,
-              /*packB=*/*packB,
-              /*C=*/reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()),
-              /*C_buffer=*/buffer.data_ptr<int32_t>(),
-              /*ldc=*/N,
-              /*outProcess=*/outputProcObj,
-              /*thread_id=*/task_id,
-              /*num_threads=*/num_tasks);
-        }
-      }
-    });
-
-    return output;
-  }
-#endif
-#ifdef USE_PYTORCH_QNNPACK
-  static at::Tensor qnnpack_linear(
-      at::Tensor input,
-      at::Tensor packed_weight,
-      double output_scale,
-      int64_t output_zero_point) {
-    TORCH_CHECK(
-        input.dim() >= 2,
-        "quantized::linear(): Input tensor rank should be >= 2");
-    auto input_contig = input.contiguous();
-
-    auto& pack_ptr =
-        cpp_custom_type_hack::cast<PackedLinearWeightsQnnp>(packed_weight);
-    auto packB = pack_ptr.w.get();
-    // Adjust weight zero point, similar to weight data.
-    auto kernel_zp = pack_ptr.w_zp + 128;
-    auto kernel_scale = pack_ptr.w_scale;
-    size_t rows_w = pack_ptr.bias.size(0);
-    size_t cols_w = input_contig.size(input_contig.dim() - 1);
-    auto input_scale = input_contig.q_scale();
-
-    if (!pack_ptr.input_scale.has_value() ||
-        pack_ptr.input_scale.value() != input_scale) {
-      // Get the original weight and adjust it to uint8 from int8
-      auto weight_contig = pack_ptr.orig_weight;
-      auto bias_fp32 = pack_ptr.bias;
-      int8_t* w_data = (int8_t*)weight_contig.data_ptr<c10::qint8>();
-      Tensor qnnp_weight = at::_empty_affine_quantized(
-          weight_contig.sizes(),
-          at::device(kCPU).dtype(kQUInt8),
-          kernel_scale,
-          kernel_zp);
-      auto* qnnp_w_data = qnnp_weight.data_ptr<c10::quint8>();
-      auto wt_numel = weight_contig.numel();
-      for (int i = 0; i < wt_numel; ++i) {
-        qnnp_w_data[i] = static_cast<c10::quint8>(w_data[i] + 128);
-      }
-      // Original bias was float, so we requantize it here.
-      auto qbias = at::quantize_per_tensor(
-          bias_fp32, kernel_scale * input_scale, 0, kQInt32);
-      // Update the input scale to not pack again.
-      pack_ptr.input_scale = input_scale;
-      pack_ptr.w.reset();
-      pack_ptr.w = std::make_unique<qnnpack::PackBMatrix>(
-          cols_w /* input_channels */,
-          rows_w /* output_channels */,
-          kernel_zp,
-          kernel_scale,
-          (uint8_t*)qnnp_w_data,
-          (int32_t*)qbias.data_ptr<c10::qint32>());
-      packB = pack_ptr.w.get();
-      if (at::globalContext().releaseWeightsWhenPrepacking()) {
-        // On mobile, we release the original weight by resetting the intrusive_ptr.
-        // Calling unpack after this will throw an assertion.
-        pack_ptr.orig_weight.reset();
-        pack_ptr.bias.reset();
-      }
-    }
-
-    size_t rows_input = 1;
-    size_t cols_input = input_contig.size(input_contig.dim() - 1);
-    for (size_t i = 0; i < input_contig.dim() - 1; ++i) {
-      rows_input *= input_contig.size(i);
-    }
-
-    TORCH_CHECK(
-        cols_input == cols_w,
-        "quantized::linear(): input size does not match weight dimension 1 size: \
-         got ",
-        cols_input,
-        " but expected ",
-        cols_w);
-
-    // Allocate output Tensor and a buffer for QNNPACK to use
-    Tensor output = at::_empty_affine_quantized(
-        {static_cast<long>(rows_input), static_cast<long>(rows_w)},
-        input.options(),
-        output_scale,
-        output_zero_point);
-
-    auto output_min = ReluFused
-        ? activationLimits(output_scale, output_zero_point, Activation::RELU)
-              .first
-        : std::numeric_limits<uint8_t>::min();
-    auto output_max = ReluFused
-        ? activationLimits(output_scale, output_zero_point, Activation::RELU)
-              .second
-        : std::numeric_limits<uint8_t>::max();
-    TORCH_INTERNAL_ASSERT(packB != nullptr, "Packed Weights are NULL");
-    const pytorch_qnnp_status runStatus = qnnpack::qnnpackLinear(
-        rows_input /* batch_size */,
-        cols_input /* input_channels */,
-        rows_w /* output_channels */,
-        input_contig.q_zero_point(),
-        input_contig.q_scale(),
-        kernel_zp,
-        kernel_scale,
-        output_zero_point,
-        output_scale,
-        output_min,
-        output_max,
-        (uint8_t*)input_contig.data_ptr<c10::quint8>(),
-        cols_input /* input_stride */,
-        packB->getPackedWeights(),
-        (uint8_t*)output.data_ptr<c10::quint8>(),
-        rows_w /* output_stride */,
-        caffe2::mobile_pthreadpool() /* threadpool */);
-
-    TORCH_INTERNAL_ASSERT(
-        runStatus == pytorch_qnnp_status_success,
-        "failed to run QNNPACK Linear operator");
-
-    return output;
-  }
-#endif
   static at::Tensor run(
       at::Tensor input,
-      at::Tensor packed_weight,
+      const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight,
       double output_scale,
       int64_t output_zero_point) {
-    auto& ctx = at::globalContext();
-
-#ifdef USE_FBGEMM
-    if (ctx.qEngine() == at::QEngine::FBGEMM) {
-      return fbgemm_linear(
-          input, packed_weight, output_scale, output_zero_point);
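+    // The packed-params object already carries the engine-specific kernel
+    // (FBGEMM or QNNPACK), so the per-engine dispatch that used to live here
+    // is no longer needed; ReluFused picks the variant with the ReLU clamp
+    // folded into the quantized output.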
+    if (ReluFused) {
+      return packed_weight->apply_relu(
+          std::move(input), output_scale, output_zero_point);
+    } else {
+      return packed_weight->apply(
+          std::move(input), output_scale, output_zero_point);
     }
-#endif
-#ifdef USE_PYTORCH_QNNPACK
-    if (ctx.qEngine() == at::QEngine::QNNPACK) {
-      return qnnpack_linear(
-          input, packed_weight, output_scale, output_zero_point);
-    }
-#endif
-    TORCH_CHECK(
-        false,
-        "Didn't find engine for operation quantized::linear ",
-        toString(ctx.qEngine()));
   }
 };
 
diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp
index c9313c7..c7a9045 100644
--- a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp
@@ -1,15 +1,389 @@
 #include <ATen/ATen.h>
 #include <ATen/Parallel.h>
-#include <torch/library.h>
-#include <ATen/cpp_custom_type_hack.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
+#include <ATen/native/quantized/cpu/packed_params.h>
 #include <ATen/native/quantized/cpu/qnnpack_utils.h>
 #include <ATen/native/quantized/cpu/quant_utils.h>
 #include <caffe2/utils/threadpool/ThreadPoolMobile.h>
+#include <torch/library.h>
+
+#include <torch/custom_class.h>
 
 #include <algorithm>
 #include <string>
 
+torch::jit::class_<LinearPackedParamsBase> register_linear_params();
+
+#ifdef USE_FBGEMM
+template <bool ReluFused>
+at::Tensor PackedLinearWeight::apply_dynamic_impl(at::Tensor input) {
+  using at::Tensor;
+  // fp32 * int8 -> fp32 (with quantization on activation, and dequantization
+  // on the result).
+
+  // We make a strong guarantee that models using these operators will have
+  // the same numerics across different machines. Therefore, we do not provide
+  // a fallback path and rather fail loudly if we cannot run FBGEMM.
+  TORCH_CHECK(
+      fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
+
+  // TODO: contiguous is called for further jit optimizations.
+  auto input_contig = input.contiguous();
+  const auto* input_ptr = input_contig.data_ptr<float>();
+
+  TORCH_CHECK(
+      input.dim() >= 2,
+      "The dimension of input tensor should be larger than or equal to 2");
+  // C(output) = A(input) x B(weight), where C, A, B are M x N, M x K, K x N
+  // matrices, respectively.
+  int64_t M = size_to_dim_(input.dim() - 1, input.sizes());
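+  // size_to_dim_(k, sizes) is the product of the first k dimensions, so M
+  // collapses every dimension except the last into a single row count.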
+
+  auto packB = w.get();
+
+  int64_t N = static_cast<int64_t>(packB->numCols());
+  int64_t K = input.size(input.dim() - 1);
+  TORCH_CHECK(
+      K == static_cast<int64_t>(packB->numRows()),
+      "The number of rows in the packB should be equal to K: " +
+          std::to_string(K));
+
+  // Calculate statistics for quantization of the input Tensor
+  float x_min, x_max;
+  fbgemm::FindMinMax(
+      /*m=*/input_ptr,
+      /*min=*/&x_min,
+      /*max=*/&x_max,
+      /*len=*/input.numel());
+
+  // Input tensor is quantized as 8-bit unsigned values
+  static constexpr int precision = 8;
+  static constexpr bool is_signed = false;
+
+  // Calculate scale and zero point for quantization of input tensor
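+  // With unsigned 8-bit activations this yields (qmin, qmax) == (0, 255).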
+  auto q_params = quant_utils::ChooseQuantizationParams(
+      /*min=*/x_min,
+      /*max=*/x_max,
+      /*qmin=*/is_signed ? -(1 << (precision - 1)) : 0,
+      /*qmax=*/
+      is_signed ? ((1 << (precision - 1)) - 1) : (1 << precision) - 1,
+      /*preserve_sparsity=*/false);
+
+  q_params.precision = precision;
+
+  // ReQuantizeForFloat requires pointers to the zero point values,
+  // since in the case of rowwise quantization these will be arrays rather
+  // than scalars. But in this case, we're doing whole-tensor quantization so
+  // we just pass a pointer to the scale values (and internally
+  // ReQuantizeForFloat won't index past 0).
+
+  const float* bias_ptr = nullptr;
+  at::Tensor bias_vec;
+  if (bias_.has_value()) {
+    bias_vec = bias_.value();
+    TORCH_CHECK(bias_vec.dim() == 1, "bias should be a vector (1D Tensor)");
+    TORCH_CHECK(
+        bias_vec.size(0) == N,
+        "bias should have N elements: " + std::to_string(N));
+    // TODO: contiguous is called for further jit optimizations.
+    auto bias_contig = bias_vec.contiguous();
+    bias_ptr = bias_contig.data_ptr<float>();
+  }
+  // The resulting matrix here is 2-D, let's view it with the original
+  // left hand dimensions of the input. Here are two examples:
+  // 1. If the input tensor is {M, K}, the output tensor is {M, N}.
+  // 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.
+  std::vector<int64_t> out_sizes = input.sizes().vec();
+  out_sizes.back() = N;
+  // Allocate output Tensor and a buffer for fbgemmPacked to use
+  auto output = at::empty(out_sizes, input.options().dtype(at::kFloat));
+  auto buffer = at::empty_like(
+      output,
+      output.options().dtype(at::kInt),
+      LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+
+  int num_tasks = at::get_num_threads();
+  at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
+    // This operation does the following:
+    // 1) Quantizes the input matrix given the statistics we've calculated
+    // above
+    // 2) Creates a "row buffer" vector with offset values that must be
+    // added
+    //    to the integer matrix multiplication operation to ensure
+    //    correctness. This "row buffer" is also called the row offset, and it
+    //    is needed when we use affine quantization for weights.
+    // 3) Packs the resulting quantized matrix into vector-register and cache
+    //    friendly tiles.
+    //
+    //  Note this is not executed eagerly, but rather within the fbgemmPacked
+    //  call below.
+
+    fbgemm::PackAWithQuantRowOffset<uint8_t> packA(
+        /*trans=*/fbgemm::matrix_op_t::NoTranspose,
+        /*nRow=*/M,
+        /*nCol=*/K,
+        /*smat=*/input_ptr,
+        /*ld=*/K,
+        /*pmat=*/nullptr, // Currently, packA manages ownership of `pmat`.
+        /*scale=*/q_params.scale,
+        /*zero_pt=*/q_params.zero_point);
+    // TODO: Consider a way to pre-allocate and reuse
+    // pmat buffer.
+
+    // This is the end of the pipeline, pass the resulting matrix through.
+    fbgemm::DoNothing<float, float> doNothingObj{};
+
+    for (int task_id = begin; task_id < end; ++task_id) {
+      if (q_scheme == c10::kPerTensorAffine) {
+        // Process the per tensor quantization.
+        //
+        // After the uint8 * int8 matrix multiplication is performed, this
+        // operation does:
+        //  1) Add in row and column offsets to the rows and columns,
+        //  respectively.
+        //  2) Dequantize the results into floating point.
+        //  3) Add in the bias term.
+        fbgemm::ReQuantizeForFloat<ReluFused> outputProcObj(
+            /*nextop=*/doNothingObj,
+            /*Aq_scale=*/q_params.scale,
+            /*Bq_scale=*/w_scale.data(),
+            /*Aq_zero_point=*/q_params.zero_point,
+            /*Bq_zero_point=*/w_zp.data(),
+            /*row_offsets=*/packA.getRowOffsetBuffer(),
+            /*col_offsets=*/col_offsets.data(),
+            /*bias=*/bias_ptr,
+            /*nCol=*/N);
+
+        // Do the GEMM
+        fbgemm::fbgemmPacked(
+            /*packA=*/packA,
+            /*packB=*/*packB,
+            /*C=*/output.data_ptr<float>(),
+            /*C_buffer=*/buffer.data_ptr<int32_t>(),
+            /*ldc=*/N,
+            /*outProcess=*/outputProcObj,
+            /*thread_id=*/task_id,
+            /*num_threads=*/num_tasks);
+
+      } else if (q_scheme == c10::kPerChannelAffine) {
+        // Process the per channel quantization.
+        //
+        // After the uint8 * int8 matrix multiplication is performed, this
+        // operation does:
+        //  1) Add in row and column offsets to the rows and columns,
+        //  respectively.
+        //  2) Dequantize the results into floating point.
+        //  3) Add in the bias term.
+        fbgemm::ReQuantizeForFloat<
+            ReluFused,
+            fbgemm::QuantizationGranularity::OUT_CHANNEL>
+            outputProcObj(
+                /*nextop=*/doNothingObj,
+                /*Aq_scale=*/q_params.scale,
+                /*Bq_scale=*/w_scale.data(),
+                /*Aq_zero_point=*/q_params.zero_point,
+                /*Bq_zero_point=*/w_zp.data(),
+                /*row_offsets=*/packA.getRowOffsetBuffer(),
+                /*col_offsets=*/col_offsets.data(),
+                /*bias=*/bias_ptr,
+                /*nCol=*/N);
+
+        // Do the GEMM
+        fbgemm::fbgemmPacked(
+            /*packA=*/packA,
+            /*packB=*/*packB,
+            /*C=*/output.data_ptr<float>(),
+            /*C_buffer=*/buffer.data_ptr<int32_t>(),
+            /*ldc=*/N,
+            /*outProcess=*/outputProcObj,
+            /*thread_id=*/task_id,
+            /*num_threads=*/num_tasks);
+      }
+    }
+  });
+
+  return output;
+}
+
+at::Tensor PackedLinearWeight::apply_dynamic(at::Tensor input) {
+  return apply_dynamic_impl</*ReluFused=*/false>(std::move(input));
+}
+
+at::Tensor PackedLinearWeight::apply_dynamic_relu(at::Tensor input) {
+  return apply_dynamic_impl</*ReluFused=*/true>(std::move(input));
+}
+
+#endif // USE_FBGEMM
+
+#ifdef USE_PYTORCH_QNNPACK
+template <bool ReluFused>
+at::Tensor PackedLinearWeightsQnnp::apply_dynamic_impl(at::Tensor input) {
+  using at::Tensor;
+  TORCH_CHECK(
+      input.dim() >= 2,
+      "The dimension of input tensor should be larger than or equal to 2");
+  auto input_contig = input.contiguous();
+  // C(output) = A(input) x B(weight), where C, A, B are M x N, M x K, K x N
+  // matrices, respectively.
+
+  auto packB = w.get();
+  // Adjust weight zero point, similar to weight data.
+  auto kernel_zp = w_zp + 128;
+  auto kernel_scale = w_scale;
+  size_t rows_w = bias_.size(0);
+  size_t cols_w = input_contig.size(input_contig.dim() - 1);
+
+  at::Tensor bias_vec = bias_;
+
+  TORCH_CHECK(bias_vec.dim() == 1, "bias should be a vector (1D Tensor)");
+
+  auto bias_contig = bias_vec.contiguous();
+  const float* bias_ptr = bias_contig.data_ptr<float>();
+
+  // Calculate statistics for quantization of input Tensor
+  // TODO: optimized kernel
+  float x_min = input_contig.min().item<float>();
+  float x_max = input_contig.max().item<float>();
+
+  auto q_params = quant_utils::ChooseQuantizationParams(
+      /*min=*/x_min,
+      /*max=*/x_max,
+      /*qmin=*/0,
+      /*qmax=*/255);
+  if (!input_scale.has_value()) {
+    // Get the original weight and adjust it to uint8 from int8
+    auto weight_contig = orig_weight;
+    int8_t* w_data = (int8_t*)weight_contig.data_ptr<c10::qint8>();
+    Tensor qnnp_weight = at::_empty_affine_quantized(
+        weight_contig.sizes(),
+        at::device(c10::kCPU).dtype(c10::kQUInt8),
+        kernel_scale,
+        kernel_zp);
+    auto* qnnp_w_data = qnnp_weight.data_ptr<c10::quint8>();
+    auto wt_numel = weight_contig.numel();
+    for (int i = 0; i < wt_numel; ++i) {
+      qnnp_w_data[i] = static_cast<c10::quint8>(w_data[i] + 128);
+    }
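+    // Adding 128 re-biases the int8 values [-128, 127] into the uint8 range
+    // [0, 255]; the weight zero point above was shifted by the same amount.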
+
+    // Update the input scale to not pack again.
+    // Pass in nullptr for bias, as we pass FP32 bias to run function.
+    input_scale = q_params.scale;
+    w.reset();
+    w = std::make_unique<qnnpack::PackBMatrix>(
+        cols_w /* input_channels */,
+        rows_w /* output_channels */,
+        kernel_zp,
+        kernel_scale,
+        (uint8_t*)qnnp_w_data,
+        nullptr);
+    packB = w.get();
+    if (at::globalContext().releaseWeightsWhenPrepacking()) {
+      // On mobile, we release the original weight by resetting the intrusive_ptr.
+      // Calling unpack after this will throw an assertion.
+      orig_weight.reset();
+    }
+  }
+
+  // Quantize input
+  Tensor q_input = at::quantize_per_tensor(
+      input_contig, q_params.scale, q_params.zero_point, c10::kQUInt8);
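+  // The activations are quantized on the fly with the parameters chosen
+  // above; qnnpackLinearDynamic consumes uint8 activations and writes fp32
+  // output directly.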
+
+  // The resulting matrix here is 2-D, let's view it with the original
+  // left hand dimensions of the input. Here are two examples:
+  // 1. If the input tensor is {M, K}, the output tensor is {M, N}.
+  // 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.
+  std::vector<int64_t> out_sizes = input.sizes().vec();
+  out_sizes.back() = rows_w;
+
+  auto output = at::empty(out_sizes, input.options().dtype(at::kFloat));
+
+  size_t rows_input = 1;
+  size_t cols_input = input_contig.size(input_contig.dim() - 1);
+  for (size_t i = 0; i < input_contig.dim() - 1; ++i) {
+    rows_input *= input_contig.size(i);
+  }
+  pytorch_qnnp_status runStatus = qnnpack::qnnpackLinearDynamic(
+      rows_input /* batch_size */,
+      cols_input /* input_channels */,
+      rows_w /* output_channels */,
+      q_input.q_zero_point(),
+      q_input.q_scale(),
+      kernel_zp,
+      kernel_scale,
+      (uint8_t*)q_input.data_ptr<c10::quint8>(),
+      cols_input /* input_stride */,
+      packB->getPackedWeights(),
+      bias_ptr,
+      output.data_ptr<float>(),
+      rows_w /* output_stride */,
+      caffe2::mobile_pthreadpool() /* threadpool */);
+
+  TORCH_INTERNAL_ASSERT(
+      runStatus == pytorch_qnnp_status_success,
+      "failed to run QNNPACK Linear operator");
+  return output;
+}
+
+at::Tensor PackedLinearWeightsQnnp::apply_dynamic(at::Tensor input) {
+  return apply_dynamic_impl</*ReluFused=*/false>(std::move(input));
+}
+
+at::Tensor PackedLinearWeightsQnnp::apply_dynamic_relu(at::Tensor input) {
+  return apply_dynamic_impl</*ReluFused=*/true>(std::move(input));
+}
+
+#endif // USE_PYTORCH_QNNPACK
+
+#ifdef USE_FBGEMM
+
+template <bool ReluFused>
+at::Tensor PackedLinearWeightFp16::apply_dynamic_impl(at::Tensor input) {
+  const at::Tensor input_contig = input.contiguous();
+  const float* input_ptr = input_contig.data_ptr<float>();
+
+  auto& packed_weight_fp16 = *w;
+
+  TORCH_CHECK(input.size(input.dim() - 1) == packed_weight_fp16.numRows())
+  TORCH_CHECK(input.dim() >= 2);
+
+  const int64_t M = size_to_dim_(input.dim() - 1, input.sizes());
+  const int64_t N = packed_weight_fp16.numCols();
+  std::vector<int64_t> output_size = input.sizes().vec();
+  output_size.back() = N;
+  at::Tensor output = at::empty(output_size, input.options().dtype(at::kFloat));
+
+  // Call the fp16 gemm interface
+  fbgemm::cblas_gemm_compute(
+      fbgemm::matrix_op_t::NoTranspose,
+      M,
+      input_ptr,
+      packed_weight_fp16,
+      0.0f,
+      output.data_ptr<float>());
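+  // beta == 0.0f, so the GEMM overwrites the freshly allocated output buffer
+  // rather than accumulating into it; the bias (if any) is added below.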
+
+  // Add bias term
+  if (bias_.has_value()) {
+    TORCH_CHECK(bias_->dim() == 1);
+    output.add_(*bias_);
+  }
+
+  return output;
+}
+
+at::Tensor PackedLinearWeightFp16::apply_dynamic(at::Tensor input) {
+  return apply_dynamic_impl</*ReluFused=*/false>(std::move(input));
+}
+
+at::Tensor PackedLinearWeightFp16::apply_dynamic_relu(at::Tensor input) {
+  return apply_dynamic_impl</*ReluFused=*/true>(std::move(input));
+}
+
+void PackedLinearWeightFp16::set_bias(c10::optional<at::Tensor> bias) {
+  bias_ = std::move(bias);
+}
+
+#endif // USE_FBGEMM
+
 namespace at {
 namespace native {
 namespace {
@@ -17,328 +391,16 @@
 template <bool ReluFused>
 class QLinearDynamicInt8 final {
  public:
-#ifdef USE_FBGEMM
-  static at::Tensor fbgemm_linear(at::Tensor input, at::Tensor packed_weight) {
-    // fp32 * int8 -> fp32 (with quantization on activation, and dequantization
-    // on the result).
-
-    // We make a strong guarantee that models using these operators will have
-    // the same numerics across different machines. Therefore, we do not provide
-    // a fallback path and rather fail loudly if we cannot run FBGEMM.
-    TORCH_CHECK(
-        fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
-
-    // TODO: contiguous is called for further jit optimizations.
-    auto input_contig = input.contiguous();
-    const auto* input_ptr = input_contig.data_ptr<float>();
-
-    TORCH_CHECK(
-        input.dim() >= 2,
-        "The dimension of input tensor should be larger than or equal to 2");
-    // C(output) = A(input) x B(weight), where C, A, B are M x N, M x K, K x N
-    // matrices, respectively.
-    int64_t M = size_to_dim_(input.dim() - 1, input.sizes());
-
-    // Pull out the PackBMatrix and col_offsets instance from the owning tensor.
-    auto& pack_ptr =
-        cpp_custom_type_hack::cast<PackedLinearWeight>(packed_weight);
-    auto packB = pack_ptr.w.get();
-    // packB->printPackedMatrix("packedB inside fbgemm_linear_dynamic
-    // (QLinearDynamicInt8): ");
-    auto& col_offsets = pack_ptr.col_offsets;
-
-    int64_t N = static_cast<int64_t>(packB->numCols());
-    int64_t K = input.size(input.dim() - 1);
-    TORCH_CHECK(
-        K == static_cast<int64_t>(packB->numRows()),
-        "The number of rows in the packB should be equal to K: " +
-            std::to_string(K));
-
-    // Calculate statistics for quantization of the input Tensor
-    float x_min, x_max;
-    fbgemm::FindMinMax(
-        /*m=*/input_ptr,
-        /*min=*/&x_min,
-        /*max=*/&x_max,
-        /*len=*/input.numel());
-
-    // Input tensor is quantized as 8-bit unsigned values
-    static constexpr int precision = 8;
-    static constexpr bool is_signed = false;
-
-    // Calculate scale and zero point for quantization of input tensor
-    auto q_params = quant_utils::ChooseQuantizationParams(
-        /*min=*/x_min,
-        /*max=*/x_max,
-        /*qmin=*/is_signed ? -(1 << (precision - 1)) : 0,
-        /*qmax=*/
-        is_signed ? ((1 << (precision - 1)) - 1) : (1 << precision) - 1,
-        /*preserve_sparsity=*/false);
-
-    q_params.precision = precision;
-
-    // ReQuantizeForFloat requires pointers to the zero point values,
-    // since in the case of rowwise quantization these will be arrays rather
-    // than scalars. But in this case, we're doing whole-tensor quantization so
-    // we just pass a pointer to the scale values (and internally
-    // ReQuantizeForFloat won't index past 0.
-
-    const float* bias_ptr = nullptr;
-    at::Tensor bias_vec;
-    if (pack_ptr.bias.has_value()) {
-      bias_vec = pack_ptr.bias.value();
-      TORCH_CHECK(bias_vec.dim() == 1, "bias should be a vector (1D Tensor)");
-      TORCH_CHECK(
-          bias_vec.size(0) == N,
-          "bias should have N elements: " + std::to_string(N));
-      // TODO: contiguous is called for further jit optimizations.
-      auto bias_contig = bias_vec.contiguous();
-      bias_ptr = bias_contig.data_ptr<float>();
-    }
-    // The resulting matrix here is 2-D, let's view it with the original
-    // left hand dimensions of the input. Here are two examples:
-    // 1. If the input tensor is {M, K}, the output tensor is {M, N}.
-    // 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.
-    std::vector<int64_t> out_sizes = input.sizes().vec();
-    out_sizes.back() = N;
-    // Allocate output Tensor and a buffer for fbgemmPacked to use
-    auto output = at::empty(out_sizes, input.options().dtype(at::kFloat));
-    auto buffer = at::empty_like(
-        output,
-        output.options().dtype(at::kInt),
-        LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-
-    int num_tasks = at::get_num_threads();
-    at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
-      // This operation does the following:
-      // 1) Quantizes the input matrix given the statistics we've calculated
-      // above
-      // 2) Creates a "row buffer" vector with offset values that must be
-      // added
-      //    to the integer matrix multiplication operation to ensure
-      //    correctness. This "row buffer" is also called the row offset, and it
-      //    is needed when we use affine quantization for weights.
-      // 3) Packs the resulting quantized matrix into vector-register and cache
-      //    friendly tiles.
-      //
-      //  Note this is not executed eagerly, but rather within the fbgemmPacked
-      //  call below.
-
-      fbgemm::PackAWithQuantRowOffset<uint8_t> packA(
-          /*trans=*/fbgemm::matrix_op_t::NoTranspose,
-          /*nRow=*/M,
-          /*nCol=*/K,
-          /*smat=*/input_ptr,
-          /*ld=*/K,
-          /*pmat=*/nullptr, // Currently, packA manages ownership of `pmat`.
-          /*scale=*/q_params.scale,
-          /*zero_pt=*/q_params.zero_point);
-      // TODO: Consider a way to pre-allocate and reuse
-      // pmat buffer.
-
-      // This is the end of the pipeline, pass the resulting matrix through.
-      fbgemm::DoNothing<float, float> doNothingObj{};
-
-      for (int task_id = begin; task_id < end; ++task_id) {
-        if (pack_ptr.q_scheme == kPerTensorAffine) {
-          // Process the per tensor quantization.
-          //
-          // After the uint8 * int8 matrix multiplication is performed, this
-          // operation does:
-          //  1) Add in row and column offsets to the rows and columns,
-          //  respectively.
-          //  2) Dequantize the results into floating point.
-          //  3) Add in the bias term.
-          fbgemm::ReQuantizeForFloat<ReluFused> outputProcObj(
-              /*nextop=*/doNothingObj,
-              /*Aq_scale=*/q_params.scale,
-              /*Bq_scale=*/pack_ptr.w_scale.data(),
-              /*Aq_zero_point=*/q_params.zero_point,
-              /*Bq_zero_point=*/pack_ptr.w_zp.data(),
-              /*row_offsets=*/packA.getRowOffsetBuffer(),
-              /*col_offsets=*/col_offsets.data(),
-              /*bias=*/bias_ptr,
-              /*nCol=*/N);
-
-          // Do the GEMM
-          fbgemm::fbgemmPacked(
-              /*packA=*/packA,
-              /*packB=*/*packB,
-              /*C=*/output.data_ptr<float>(),
-              /*C_buffer=*/buffer.data_ptr<int32_t>(),
-              /*ldc=*/N,
-              /*outProcess=*/outputProcObj,
-              /*thread_id=*/task_id,
-              /*num_threads=*/num_tasks);
-
-        } else if (pack_ptr.q_scheme == kPerChannelAffine) {
-          // Process the per channel quantization.
-          //
-          // After the uint8 * int8 matrix multiplication is performed, this
-          // operation does:
-          //  1) Add in row and column offsets to the rows and columns,
-          //  respectively.
-          //  2) Dequantize the results into floating point.
-          //  3) Add in the bias term.
-          fbgemm::ReQuantizeForFloat<
-              ReluFused,
-              fbgemm::QuantizationGranularity::OUT_CHANNEL>
-              outputProcObj(
-                  /*nextop=*/doNothingObj,
-                  /*Aq_scale=*/q_params.scale,
-                  /*Bq_scale=*/pack_ptr.w_scale.data(),
-                  /*Aq_zero_point=*/q_params.zero_point,
-                  /*Bq_zero_point=*/pack_ptr.w_zp.data(),
-                  /*row_offsets=*/packA.getRowOffsetBuffer(),
-                  /*col_offsets=*/col_offsets.data(),
-                  /*bias=*/bias_ptr,
-                  /*nCol=*/N);
-
-          // Do the GEMM
-          fbgemm::fbgemmPacked(
-              /*packA=*/packA,
-              /*packB=*/*packB,
-              /*C=*/output.data_ptr<float>(),
-              /*C_buffer=*/buffer.data_ptr<int32_t>(),
-              /*ldc=*/N,
-              /*outProcess=*/outputProcObj,
-              /*thread_id=*/task_id,
-              /*num_threads=*/num_tasks);
-        }
-      }
-    });
-
-    return output;
-  }
-#endif // USE_FBGEMM
-#ifdef USE_PYTORCH_QNNPACK
-
-  static at::Tensor qnnpack_linear(at::Tensor input, at::Tensor packed_weight) {
-    TORCH_CHECK(
-        input.dim() >= 2,
-        "The dimension of input tensor should be larger than or equal to 2");
-    auto input_contig = input.contiguous();
-    // C(output) = A(input) x B(weight), where C, A, B are M x N, M x K, K x N
-    // matrices, respectively.
-
-    auto& pack_ptr =
-        cpp_custom_type_hack::cast<PackedLinearWeightsQnnp>(packed_weight);
-    auto packB = pack_ptr.w.get();
-    // Adjust weight zero point, similar to weight data.
-    auto kernel_zp = pack_ptr.w_zp + 128;
-    auto kernel_scale = pack_ptr.w_scale;
-    size_t rows_w = pack_ptr.bias.size(0);
-    size_t cols_w = input_contig.size(input_contig.dim() - 1);
-
-    at::Tensor bias_vec = pack_ptr.bias;
-
-    TORCH_CHECK(bias_vec.dim() == 1, "bias should be a vector (1D Tensor)");
-
-    auto bias_contig = bias_vec.contiguous();
-    const float* bias_ptr = bias_contig.data_ptr<float>();
-
-    // Calculate statistics for quantization of input Tensor
-    // TODO: optimized kernel
-    float x_min = input_contig.min().item<float>();
-    float x_max = input_contig.max().item<float>();
-
-    auto q_params = quant_utils::ChooseQuantizationParams(
-        /*min=*/x_min,
-        /*max=*/x_max,
-        /*qmin=*/0,
-        /*qmax=*/255);
-    if (!pack_ptr.input_scale.has_value()) {
-      // Get the original weight and adjust it to uint8 from int8
-      auto weight_contig = pack_ptr.orig_weight;
-      int8_t* w_data = (int8_t*)weight_contig.data_ptr<c10::qint8>();
-      Tensor qnnp_weight = at::_empty_affine_quantized(
-          weight_contig.sizes(),
-          at::device(kCPU).dtype(kQUInt8),
-          kernel_scale,
-          kernel_zp);
-      auto* qnnp_w_data = qnnp_weight.data_ptr<c10::quint8>();
-      auto wt_numel = weight_contig.numel();
-      for (int i = 0; i < wt_numel; ++i) {
-        qnnp_w_data[i] = static_cast<c10::quint8>(w_data[i] + 128);
-      }
-
-      // Update the input scale to not pack again.
-      // Pass in nullptr for bias, as we pass FP32 bias to run function.
-      pack_ptr.input_scale = q_params.scale;
-      pack_ptr.w.reset();
-      pack_ptr.w = std::make_unique<qnnpack::PackBMatrix>(
-          cols_w /* input_channels */,
-          rows_w /* output_channels */,
-          kernel_zp,
-          kernel_scale,
-          (uint8_t*)qnnp_w_data,
-          nullptr);
-      packB = pack_ptr.w.get();
-      if (at::globalContext().releaseWeightsWhenPrepacking()) {
-        // On mobile, we release the original weight by resetting the intrusive_ptr.
-        // Calling unpack after this will throw an assertion.
-        pack_ptr.orig_weight.reset();
-      }
-    }
-
-    // Quantize input
-    Tensor q_input = at::quantize_per_tensor(
-        input_contig, q_params.scale, q_params.zero_point, kQUInt8);
-
-    // The resulting matrix here is 2-D, let's view it with the original
-    // left hand dimensions of the input. Here are two examples:
-    // 1. If the input tensor is {M, K}, the output tensor is {M, N}.
-    // 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.
-    std::vector<int64_t> out_sizes = input.sizes().vec();
-    out_sizes.back() = rows_w;
-
-    auto output = at::empty(out_sizes, input.options().dtype(at::kFloat));
-
-    size_t rows_input = 1;
-    size_t cols_input = input_contig.size(input_contig.dim() - 1);
-    for (size_t i = 0; i < input_contig.dim() - 1; ++i) {
-      rows_input *= input_contig.size(i);
-    }
-    pytorch_qnnp_status runStatus = qnnpack::qnnpackLinearDynamic(
-        rows_input /* batch_size */,
-        cols_input /* input_channels */,
-        rows_w /* output_channels */,
-        q_input.q_zero_point(),
-        q_input.q_scale(),
-        kernel_zp,
-        kernel_scale,
-        (uint8_t*)q_input.data_ptr<c10::quint8>(),
-        cols_input /* input_stride */,
-        packB->getPackedWeights(),
-        bias_ptr,
-        output.data_ptr<float>(),
-        rows_w /* output_stride */,
-        caffe2::mobile_pthreadpool() /* threadpool */);
-
-    TORCH_INTERNAL_ASSERT(
-        runStatus == pytorch_qnnp_status_success,
-        "failed to run QNNPACK Linear operator");
-    return output;
-  }
-#endif // USE_PYTORCH_QNNPACK
-
-  static at::Tensor run(at::Tensor input, at::Tensor packed_weight) {
+  static at::Tensor run(
+      at::Tensor input,
+      const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight) {
     auto& ctx = at::globalContext();
 
-#ifdef USE_FBGEMM
-    if (ctx.qEngine() == at::QEngine::FBGEMM) {
-      return fbgemm_linear(input, packed_weight);
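+    // Dynamic quantization: the activation scale and zero point are computed
+    // per call inside apply_dynamic*, so only the input tensor is passed.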
+    if (ReluFused) {
+      return packed_weight->apply_dynamic_relu(std::move(input));
+    } else {
+      return packed_weight->apply_dynamic(std::move(input));
     }
-#endif
-#ifdef USE_PYTORCH_QNNPACK
-    if (ctx.qEngine() == at::QEngine::QNNPACK) {
-      return qnnpack_linear(input, packed_weight);
-    }
-#endif
-    TORCH_CHECK(
-        false,
-        "Didn't find engine for operation quantized::linear ",
-        toString(ctx.qEngine()));
   }
 };
 
@@ -346,52 +408,22 @@
 class QLinearDynamicFp16 final {
  public:
 #ifdef USE_FBGEMM
-  static at::Tensor run(at::Tensor input, at::Tensor packed_weight) {
+  static at::Tensor run(
+      at::Tensor input,
+      const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight) {
     // We make a strong guarantee that models using these operators will have
     // the same numerics across different machines. Therefore, we do not provide
     // a fallback path and rather fail loudly if we cannot run FBGEMM.
     TORCH_CHECK(
         fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");
 
-    const Tensor input_contig = input.contiguous();
-    const float* input_ptr = input_contig.data_ptr<float>();
-
-    // Pull out the PackedGemmMatrixFP16 instance from the owning tensor
-    auto& packed_param_struct =
-        cpp_custom_type_hack::cast<PackedLinearWeightFp16>(packed_weight);
-    auto& packed_weight_fp16 = *packed_param_struct.w;
-    auto& bias = packed_param_struct.bias;
-
-    TORCH_CHECK(input.size(input.dim() - 1) == packed_weight_fp16.numRows())
-    TORCH_CHECK(input.dim() >= 2);
-
-    const int64_t M = size_to_dim_(input.dim() - 1, input.sizes());
-    const int64_t N = packed_weight_fp16.numCols();
-    std::vector<int64_t> output_size = input.sizes().vec();
-    output_size.back() = N;
-    Tensor output = at::empty(output_size, input.options().dtype(at::kFloat));
-
-    // Call the fp16 gemm interface
-    fbgemm::cblas_gemm_compute(
-        fbgemm::matrix_op_t::NoTranspose,
-        M,
-        input_ptr,
-        packed_weight_fp16,
-        0.0f,
-        output.data_ptr<float>());
-
-    // Add bias term
-    if (bias.has_value()) {
-      TORCH_CHECK(bias->dim() == 1);
-      output.add_(*bias);
-    }
-
-    return output;
+    TORCH_INTERNAL_ASSERT(!ReluFused);
+    return packed_weight->apply_dynamic(std::move(input));
   }
 #else // USE_FBGEMM
   static at::Tensor run(
       at::Tensor /* input */,
-      at::Tensor /* packed_weight */) {
+      const c10::intrusive_ptr<LinearPackedParamsBase>& /* packed_weight */) {
     // We make a strong guarantee that models using these operators will have
     // the same numerics across different machines. Therefore, we do not provide
     // a fallback path and rather fail loudly if we cannot run FBGEMM.
diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp
index 2af0625..d819f0b 100644
--- a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp
@@ -1,24 +1,237 @@
 #include <ATen/ATen.h>
-#include <torch/library.h>
 #include <ATen/cpp_custom_type_hack.h>
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
 #include <ATen/native/quantized/cpu/init_qnnpack.h>
+#include <ATen/native/quantized/cpu/packed_params.h>
 #include <ATen/native/quantized/cpu/qnnpack_utils.h>
 #include <ATen/quantized/Quantizer.h>
+#include <torch/custom_class.h>
+#include <torch/library.h>
 #include <algorithm>
 #include <vector>
 
-namespace caffe2 {
+torch::jit::class_<LinearPackedParamsBase> register_linear_params();
+
 #ifdef USE_FBGEMM
-// Required for cpp_custom_type_hack to work
-CAFFE_KNOWN_TYPE(PackedLinearWeight);
-CAFFE_KNOWN_TYPE(PackedLinearWeightFp16);
+namespace {
+// Calculate the column offsets.
+// Note this includes the sum of the columns as well as the scalar term
+// B_zero_point * K, whereas the row_offsets created by
+// PackAWithQuantRowOffset is only the sum of the A rows.
+void calc_col_offsets_transpose(
+    int K,
+    int N,
+    const int8_t* Bint8,
+    int32_t* B_zero_point,
+    int32_t* col_offsets,
+    c10::QScheme qtype) {
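+  // col_offsets[i] = sum_j B[i][j] - zero_point(i) * K, with a single shared
+  // zero point for per-tensor quantization and one per output channel for
+  // per-channel quantization.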
+  for (size_t i = 0; i < N; ++i) {
+    int32_t sum = 0;
+    for (size_t j = 0; j < K; ++j) {
+      sum += Bint8[i * K + j];
+    }
+    if (qtype == c10::kPerTensorAffine) {
+      col_offsets[i] = sum - B_zero_point[0] * K;
+    } else {
+      col_offsets[i] = sum - B_zero_point[i] * K;
+    }
+  }
+}
+} // namespace
+
+c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeight::prepack(
+    at::Tensor weight,
+    c10::optional<at::Tensor> bias) {
+  TORCH_CHECK(
+      weight.dim() == 2,
+      "The weight tensor for quantized::linear_prepack (fbgemm) should"
+      " be 2-dimensional.");
+
+  auto N = weight.size(0);
+  auto K = weight.size(1);
+
+  // TODO: contiguous is called for further JIT optimizations.
+  auto weight_contig = weight.contiguous();
+  const auto qtype = weight.qscheme();
+  std::vector<int32_t> weight_zero_points_int32(1, 0);
+  if (qtype == c10::kPerTensorAffine) {
+    weight_zero_points_int32[0] = weight.q_zero_point();
+  } else if (qtype == c10::kPerChannelAffine) {
+    weight_zero_points_int32.resize(N, 0);
+    for (int i = 0; i < N; ++i) {
+      weight_zero_points_int32[i] =
+          weight.q_per_channel_zero_points()[i].item<int32_t>();
+    }
+  }
+  std::vector<float> weight_scales_float(1, 0.0);
+  if (qtype == c10::kPerTensorAffine) {
+    weight_scales_float[0] = weight.q_scale();
+  } else if (qtype == c10::kPerChannelAffine) {
+    weight_scales_float.resize(N, 0.0);
+    for (int i = 0; i < N; ++i) {
+      weight_scales_float[i] = weight.q_per_channel_scales()[i].item<float>();
+    }
+  }
+
+  int8_t* weight_ptr_int8 =
+      reinterpret_cast<int8_t*>(weight_contig.data_ptr<c10::qint8>());
+
+  std::vector<int32_t> col_offsets(N);
+  calc_col_offsets_transpose(
+      /*K=*/K,
+      /*N=*/N,
+      /*Bint8=*/weight_ptr_int8,
+      /*B_zero_point=*/weight_zero_points_int32.data(),
+      /*col_offsets=*/col_offsets.data(),
+      /*qtype=*/qtype);
+
+  c10::optional<at::Tensor> bias_contig;
+  if (bias.has_value()) {
+    at::Tensor bias_vec = bias.value();
+    TORCH_CHECK(bias_vec.dim() == 1, "bias should be a vector (1D Tensor)");
+    TORCH_CHECK(
+        bias_vec.size(0) == N,
+        "bias should have N elements: " + std::to_string(N));
+    bias_contig = bias->contiguous();
+  }
+  auto ret_ptr = c10::make_intrusive<PackedLinearWeight>(
+      std::make_unique<fbgemm::PackBMatrix<int8_t>>(
+          /*trans=*/fbgemm::matrix_op_t::Transpose,
+          /*nRow=*/K,
+          /*nCol=*/N,
+          /*smat=*/weight_ptr_int8,
+          /*ld=*/K,
+          /*pmat=*/nullptr, // PackBMatrix manages ownership of pmat
+          /*groups=*/1),
+      bias_contig,
+      col_offsets,
+      weight_scales_float,
+      weight_zero_points_int32,
+      qtype);
+  return ret_ptr;
+}
 #endif // USE_FBGEMM
+
 #ifdef USE_PYTORCH_QNNPACK
-// Required for cpp_custom_type_hack to work
-CAFFE_KNOWN_TYPE(PackedLinearWeightsQnnp);
+c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeightsQnnp::prepack(
+    at::Tensor weight,
+    c10::optional<at::Tensor> bias_in) {
+  TORCH_CHECK(
+      weight.dim() == 2,
+      "quantized::linear_prepack (qnnpack): Weight tensor rank should be == 2");
+  TORCH_CHECK(
+      weight.qscheme() == c10::kPerTensorAffine,
+      "quantized::linear_prepack (qnnpack) only supports Per Tensor Quantization Scheme")
+
+  int64_t rows_w = weight.size(0);
+  at::Tensor bias_fp32;
+  if (bias_in.has_value()) {
+    bias_fp32 = bias_in.value();
+  } else {
+    bias_fp32 = at::zeros(rows_w, weight.options().dtype(at::kFloat));
+  }
+  TORCH_CHECK(
+      !bias_fp32.defined() ||
+          (bias_fp32.ndimension() == 1 && bias_fp32.size(0) == rows_w),
+      "quantized::linear_prepack (qnnpack): Given weight of size ",
+      weight.sizes(),
+      ", expected bias to be 1-dimensional with ",
+      rows_w,
+      " elements",
+      ", but got bias of size ",
+      bias_fp32.sizes(),
+      " instead");
+
+  at::Tensor weight_contig = weight.contiguous();
+  auto weight_zp = weight.q_zero_point();
+
+  at::native::initQNNPACK();
+
+  // We set the pre-packed linear weights to nullptr below as we call pre-pack
+  // during the first invocation of operator run. Refer to qlinear.cpp for more
+  // details. TODO Update to actually call pre-pack here once bias is removed
+  // from pre-packing step.
+  auto wt_ptr = c10::make_intrusive<PackedLinearWeightsQnnp>(
+      nullptr,
+      weight_contig, /* int8_t weight */
+      bias_fp32.contiguous(), /* fp32 bias */
+      c10::nullopt, /* input_scale */
+      weight.q_scale(),
+      weight_zp);
+  return wt_ptr;
+}
 #endif // USE_PYTORCH_QNNPACK
-} // namespace caffe2
+
+#ifdef USE_FBGEMM
+namespace {
+float RawUint16ToFp16(unsigned short value) {
+  // Convert raw 16 bits half precision floating point number
+  // to single precision floating point number.
+  const unsigned short sign_bits = value >> 15;
+  const unsigned short exponent_bits = value >> 10 & 0x1f;
+  const unsigned short significand_bits = value & 0x3ff;
+
+  const float sign = sign_bits ? -1 : 1;
+  const float significand =
+      1 + significand_bits * 0.0009765625f; // 0.0009765625f = 0x1p-10 = 2^-10;
+  const float exponent = exponent_bits - 0xf;
+
+  return sign * std::ldexp(significand, exponent);
+}
+
+template <typename T>
+bool CheckAndSaturate(T max_val, T* element) {
+  if (*element > max_val) {
+    *element = max_val;
+    return true;
+  }
+  if (*element < -max_val) {
+    *element = -max_val;
+    return true;
+  }
+  return false;
+}
+
+// The range for using FP16 quantization of weights requires that the elements
+// should be in the range of [5.96e-8, 65504]. If it is out of range, then the
+// number will be saturated to max or min representable values by FP16.
+void HandleWeightsSaturation(int64_t N, float* weight) {
+  const float kFp16Max = RawUint16ToFp16(0x7BFF);
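+  // 0x7BFF is the bit pattern of the largest finite fp16 value, i.e.
+  // kFp16Max == 65504.f.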
+  bool found_out_of_range = false;
+  for (int64_t i = 0; i < N; ++i) {
+    if (CheckAndSaturate<float>(kFp16Max, weight + i)) {
+      found_out_of_range = true;
+    }
+  }
+  if (found_out_of_range) {
+    TORCH_WARN("FOUND weight out of range ");
+  }
+}
+} // namespace
+
+c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeightFp16::prepack(
+    at::Tensor weight,
+    c10::optional<at::Tensor> bias) {
+  const int64_t K = weight.size(1);
+  const int64_t N = weight.size(0);
+  at::Tensor weight_contig = weight.contiguous();
+  float* weight_contig_ptr = weight_contig.data_ptr<float>();
+  HandleWeightsSaturation(K * N, weight_contig_ptr);
+
+  // TODO(mingzhe09088):
+  // Consider using a functor here in PackedGemmMatrixFP16
+  // Comments from (XQ): Not entirely sure this make_unique is safe.
+  // make_unique is created with regular "new", and freed through
+  // TypeMetaData::deleteFn in this function. This is perfectly fine if the
+  // tensors are created and freed within this translation unit. It might be
+  // very problematic if that tensor flows across dll boundaries.
+  auto ptr = c10::make_intrusive<PackedLinearWeightFp16>(
+      std::make_unique<fbgemm::PackedGemmMatrixFP16>(
+          fbgemm::matrix_op_t::Transpose, K, N, 1, weight_contig_ptr),
+      bias);
+  return ptr;
+}
+#endif // USE_FBGEMM
 
 namespace at {
 namespace native {
@@ -26,164 +239,20 @@
 
 class QLinearPackWeightInt8 final {
  public:
-#ifdef USE_FBGEMM
-  // Calculate the column offsets.
-  // Note this includes the sum of the columns as well as the scalar term
-  // B_zero_point * K, whereas the row_offsets created by
-  // PackAWithQuantRowOffset is only the sum of the A rows.
-  static void calc_col_offsets_transpose(
-      int K,
-      int N,
-      const int8_t* Bint8,
-      int32_t* B_zero_point,
-      int32_t* col_offsets,
-      c10::QScheme qtype) {
-    for (size_t i = 0; i < N; ++i) {
-      int32_t sum = 0;
-      for (size_t j = 0; j < K; ++j) {
-        sum += Bint8[i * K + j];
-      }
-      if (qtype == kPerTensorAffine) {
-        col_offsets[i] = sum - B_zero_point[0] * K;
-      } else {
-        col_offsets[i] = sum - B_zero_point[i] * K;
-      }
-    }
-  }
-  static at::Tensor fbgemm_linear_prepack(
+  static c10::intrusive_ptr<LinearPackedParamsBase> run(
       at::Tensor weight,
       c10::optional<Tensor> bias) {
-    TORCH_CHECK(
-        weight.dim() == 2,
-        "The weight tensor for quantized::linear_prepack (fbgemm) should"
-        " be 2-dimensional.");
-
-    auto N = weight.size(0);
-    auto K = weight.size(1);
-
-    // TODO: contiguous is called for further JIT optimizations.
-    auto weight_contig = weight.contiguous();
-    const auto qtype = weight.qscheme();
-    std::vector<int32_t> weight_zero_points_int32(1, 0);
-    if (qtype == kPerTensorAffine) {
-      weight_zero_points_int32[0] = weight.q_zero_point();
-    } else if (qtype == kPerChannelAffine) {
-      weight_zero_points_int32.resize(N, 0);
-      for (int i = 0; i < N; ++i) {
-        weight_zero_points_int32[i] =
-            weight.q_per_channel_zero_points()[i].item<int32_t>();
-      }
-    }
-    std::vector<float> weight_scales_float(1, 0.0);
-    if (qtype == kPerTensorAffine) {
-      weight_scales_float[0] = weight.q_scale();
-    } else if (qtype == kPerChannelAffine) {
-      weight_scales_float.resize(N, 0.0);
-      for (int i = 0; i < N; ++i) {
-        weight_scales_float[i] = weight.q_per_channel_scales()[i].item<float>();
-      }
-    }
-
-    int8_t* weight_ptr_int8 =
-        reinterpret_cast<int8_t*>(weight_contig.data_ptr<c10::qint8>());
-
-    std::vector<int32_t> col_offsets(N);
-    calc_col_offsets_transpose(
-        /*K=*/K,
-        /*N=*/N,
-        /*Bint8=*/weight_ptr_int8,
-        /*B_zero_point=*/weight_zero_points_int32.data(),
-        /*col_offsets=*/col_offsets.data(),
-        /*qtype=*/qtype);
-
-    c10::optional<at::Tensor> bias_contig;
-    if (bias.has_value()) {
-      Tensor bias_vec = bias.value();
-      TORCH_CHECK(bias_vec.dim() == 1, "bias should be a vector (1D Tensor)");
-      TORCH_CHECK(
-          bias_vec.size(0) == N,
-          "bias should have N elements: " + std::to_string(N));
-      bias_contig = bias->contiguous();
-    }
-    auto ret_ptr = std::make_unique<PackedLinearWeight>(PackedLinearWeight{
-        std::make_unique<fbgemm::PackBMatrix<int8_t>>(
-            /*trans=*/fbgemm::matrix_op_t::Transpose,
-            /*nRow=*/K,
-            /*nCol=*/N,
-            /*smat=*/weight_ptr_int8,
-            /*ld=*/K,
-            /*pmat=*/nullptr, // PackBMatrix manages ownership of pmat
-            /*groups=*/1),
-        bias_contig,
-        col_offsets,
-        weight_scales_float,
-        weight_zero_points_int32,
-        qtype});
-
-    // TODO: we will need to replace this with torchscript classes at a later
-    // point.
-    return cpp_custom_type_hack::create(std::move(ret_ptr), weight.options());
-  }
-#endif
-#ifdef USE_PYTORCH_QNNPACK
-  static at::Tensor qnnpack_linear_prepack(
-      at::Tensor weight,
-      c10::optional<Tensor> bias_in) {
-    TORCH_CHECK(
-        weight.dim() == 2,
-        "quantized::linear_prepack (qnnpack): Weight tensor rank should be == 2");
-    TORCH_CHECK(
-        weight.qscheme() == kPerTensorAffine,
-        "quantized::linear_prepack (qnnpack) only supports Per Tensor Quantization Scheme")
-
-    int64_t rows_w = weight.size(0);
-    Tensor bias_fp32;
-    if (bias_in.has_value()) {
-      bias_fp32 = bias_in.value();
-    } else {
-      bias_fp32 = at::zeros(rows_w, weight.options().dtype(at::kFloat));
-    }
-    TORCH_CHECK(
-        !bias_fp32.defined() || (bias_fp32.ndimension() == 1 && bias_fp32.size(0) == rows_w),
-        "quantized::linear_prepack (qnnpack): Given weight of size ",
-        weight.sizes(),
-        ", expected bias to be 1-dimensional with ",
-        rows_w,
-        " elements",
-        ", but got bias of size ",
-        bias_fp32.sizes(),
-        " instead");
-
-    Tensor weight_contig = weight.contiguous();
-    auto weight_zp = weight.q_zero_point();
-
-    initQNNPACK();
-
-    // We set the pre-packed linear weights to nullptr below as we call pre-pack
-    // during the first invocation of operator run. Refer to qlinear.cpp for more
-    // details. TODO Update to actually call pre-pack here once bias is removed
-    // from pre-packing step.
-    auto wt_ptr = std::make_unique<PackedLinearWeightsQnnp>(
-        PackedLinearWeightsQnnp{nullptr,
-                                weight_contig, /* int8_t weight */
-                                bias_fp32.contiguous(), /* fp32 bias */
-                                c10::nullopt, /* input_scale */
-                                weight.q_scale(),
-                                weight_zp});
-    return cpp_custom_type_hack::create(std::move(wt_ptr), weight.options());
-  }
-#endif
-  static at::Tensor run(at::Tensor weight, c10::optional<Tensor> bias) {
     auto& ctx = at::globalContext();
 
 #ifdef USE_FBGEMM
     if (ctx.qEngine() == at::QEngine::FBGEMM) {
-      return fbgemm_linear_prepack(weight, bias);
+      return PackedLinearWeight::prepack(std::move(weight), std::move(bias));
     }
 #endif
 #ifdef USE_PYTORCH_QNNPACK
     if (ctx.qEngine() == at::QEngine::QNNPACK) {
-      return qnnpack_linear_prepack(weight, bias);
+      return PackedLinearWeightsQnnp::prepack(
+          std::move(weight), std::move(bias));
     }
 #endif
     TORCH_CHECK(
@@ -194,51 +263,23 @@
 };
 
 class QLinearPackWeightFp16 final {
-public:
-#ifdef USE_FBGEMM
-  static at::Tensor fbgemm_linear_prepack_fp16(
+ public:
+  static c10::intrusive_ptr<LinearPackedParamsBase> run(
       at::Tensor weight,
       c10::optional<Tensor> bias) {
-    const int64_t K = weight.size(1);
-    const int64_t N = weight.size(0);
-    Tensor weight_contig = weight.contiguous();
-    float* weight_contig_ptr = weight_contig.data_ptr<float>();
-    HandleWeightsSaturation(K * N, weight_contig_ptr);
-
-    // TODO(mingzhe09088):
-    // Consider using a functor here in PackedGemmMatrixFP16
-    // Comments from (XQ): Not entirely sure this make_unique is safe.
-    // make_unique is created with regular "new", and freed through
-    // TypeMetaData::deleteFn in this function. This is perfectly fine if the
-    // tensors are created and freed within this translation unit. It might be
-    // very problematic if that tensor flows across dll boundaries.
-    auto ptr = std::make_unique<PackedLinearWeightFp16>(PackedLinearWeightFp16{
-        std::make_unique<fbgemm::PackedGemmMatrixFP16>(
-            fbgemm::matrix_op_t::Transpose, K, N, 1, weight_contig_ptr),
-        bias});
-    return cpp_custom_type_hack::create(std::move(ptr), weight.options());
-  }
-#endif
-#ifdef USE_PYTORCH_QNNPACK
-  static at::Tensor qnnpack_linear_prepack_fp16(
-      at::Tensor weight,
-      c10::optional<Tensor> bias_in) {
-    TORCH_CHECK(
-        false,
-        "quantized::linear_prepack_fp16 is currently "
-        "not supported by QNNPACK");
-  }
-#endif // USE_PYTORCH_QNNPACK
-  static at::Tensor run(at::Tensor weight, c10::optional<Tensor> bias) {
     auto& ctx = at::globalContext();
 #ifdef USE_FBGEMM
     if (ctx.qEngine() == at::QEngine::FBGEMM) {
-      return fbgemm_linear_prepack_fp16(weight, bias);
+      return PackedLinearWeightFp16::prepack(
+          std::move(weight), std::move(bias));
     }
 #endif // USE_FBGEMM
 #ifdef USE_PYTORCH_QNNPACK
     if (ctx.qEngine() == at::QEngine::QNNPACK) {
-      return qnnpack_linear_prepack_fp16(weight, bias);
+      TORCH_CHECK(
+          false,
+          "quantized::linear_prepack_fp16 is currently "
+          "not supported by QNNPACK");
     }
 #endif // USE_PYTORCH_QNNPACK
     TORCH_CHECK(
@@ -246,61 +287,79 @@
         "Didn't find engine for operation quantized::linear_prepack_fp16 ",
         toString(ctx.qEngine()));
   }
+};
 
- private:
+class QLinearPackWeightInt8Legacy final {
+ public:
+  static Tensor run(at::Tensor weight, c10::optional<Tensor> bias) {
+    auto& ctx = at::globalContext();
+    auto options = weight.options();
+
 #ifdef USE_FBGEMM
-  static float RawUint16ToFp16(unsigned short value) {
-    // Convert raw 16 bits half precision floating point number
-    // to single precision floating point number.
-    const unsigned short sign_bits = value >> 15;
-    const unsigned short exponent_bits = value >> 10 & 0x1f;
-    const unsigned short significand_bits = value & 0x3ff;
-
-    const float sign = sign_bits ? -1 : 1;
-    const float significand = 1 +
-        significand_bits * 0.0009765625f; // 0.0009765625f = 0x1p-10 = 2^-10;
-    const float exponent = exponent_bits - 0xf;
-
-    return sign * std::ldexp(significand, exponent);
-  }
-
-  template <typename T>
-  static bool CheckAndSaturate(T max_val, T* element) {
-    if (*element > max_val) {
-      *element = max_val;
-      return true;
+    if (ctx.qEngine() == at::QEngine::FBGEMM) {
+      auto prepacked =
+          PackedLinearWeight::prepack(std::move(weight), std::move(bias));
+      auto wrapped =
+          std::make_unique<c10::intrusive_ptr<LinearPackedParamsBase>>(
+              std::move(prepacked));
+      return cpp_custom_type_hack::create(std::move(wrapped), options);
     }
-    if (*element < -max_val) {
-      *element = -max_val;
-      return true;
-    }
-    return false;
-  }
-
-  // The range for using FP16 quantization of weights requires that the elements
-  // should be in the range of [5.96e-8, 65504]. If it is out of range, then the
-  // number will be saturated to max or min representable values by FP16.
-  static void HandleWeightsSaturation(int64_t N, float* weight) {
-    const float kFp16Max = RawUint16ToFp16(0x7BFF);
-    bool found_out_of_range = false;
-    for (int64_t i = 0; i < N; ++i) {
-      if (CheckAndSaturate<float>(kFp16Max, weight + i)) {
-        found_out_of_range = true;
-      }
-    }
-    if (found_out_of_range) {
-      TORCH_WARN("FOUND weight out of range ");
-    }
-  }
 #endif // USE_FBGEMM
+#ifdef USE_PYTORCH_QNNPACK
+    if (ctx.qEngine() == at::QEngine::QNNPACK) {
+      auto prepacked =
+          PackedLinearWeightsQnnp::prepack(std::move(weight), std::move(bias));
+      auto wrapped =
+          std::make_unique<c10::intrusive_ptr<LinearPackedParamsBase>>(
+              std::move(prepacked));
+      return cpp_custom_type_hack::create(std::move(wrapped), options);
+    }
+#endif // USE_PYTORCH_QNNPACK
+    TORCH_CHECK(
+        false,
+        "Didn't find engine for operation quantized::linear_prepack ",
+        toString(ctx.qEngine()));
+  }
+};
+
+class QLinearPackWeightFp16Legacy final {
+ public:
+  static Tensor run(at::Tensor weight, c10::optional<Tensor> bias) {
+    auto& ctx = at::globalContext();
+    auto options = weight.options();
+#ifdef USE_FBGEMM
+    if (ctx.qEngine() == at::QEngine::FBGEMM) {
+      auto prepacked =
+          PackedLinearWeightFp16::prepack(std::move(weight), std::move(bias));
+      auto wrapped =
+          std::make_unique<c10::intrusive_ptr<LinearPackedParamsBase>>(
+              std::move(prepacked));
+      return cpp_custom_type_hack::create(std::move(wrapped), options);
+    }
+#endif // USE_FBGEMM
+#ifdef USE_PYTORCH_QNNPACK
+    if (ctx.qEngine() == at::QEngine::QNNPACK) {
+      TORCH_CHECK(
+          false,
+          "quantized::linear_prepack_fp16 is currently "
+          "not supported by QNNPACK");
+    }
+#endif // USE_PYTORCH_QNNPACK
+    TORCH_CHECK(
+        false,
+        "Didn't find engine for operation quantized::linear_prepack_fp16 ",
+        toString(ctx.qEngine()));
+  }
 };
 
 TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
   m.impl("linear_prepack", QLinearPackWeightInt8::run);
+  m.impl("linear_prepack_legacy", QLinearPackWeightInt8Legacy::run);
 }
 
 TORCH_LIBRARY_IMPL(quantized, CPU, m) {
   m.impl("linear_prepack_fp16", QLinearPackWeightFp16::run);
+  m.impl("linear_prepack_fp16_legacy", QLinearPackWeightFp16Legacy::run);
 }
 
 TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) {
@@ -309,6 +368,7 @@
 
 TORCH_LIBRARY_IMPL(_quantized, CPU, m) {
   m.impl("linear_prepack_fp16", QLinearPackWeightFp16::run);
+  m.impl("linear_prepack_fp16_legacy", QLinearPackWeightFp16Legacy::run);
 }
 
 } // namespace
diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_unpack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_unpack.cpp
index 191a060..634b8d5 100644
--- a/aten/src/ATen/native/quantized/cpu/qlinear_unpack.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qlinear_unpack.cpp
@@ -1,8 +1,78 @@
 #include <ATen/ATen.h>
-#include <torch/library.h>
 #include <ATen/cpp_custom_type_hack.h>
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
+#include <ATen/native/quantized/cpu/packed_params.h>
 #include <ATen/native/quantized/cpu/qnnpack_utils.h>
+#include <torch/custom_class.h>
+#include <torch/library.h>
+
+torch::jit::class_<LinearPackedParamsBase> register_linear_params();
+
+#ifdef USE_FBGEMM
+std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedLinearWeight::unpack() {
+  auto packB = w.get();
+
+  int64_t N = static_cast<int64_t>(packB->numCols());
+  int64_t K = static_cast<int64_t>(packB->numRows());
+
+  at::Tensor weight_origin;
+  if (q_scheme == c10::kPerTensorAffine) {
+    weight_origin = at::_empty_affine_quantized(
+        {N, K}, at::device(c10::kCPU).dtype(c10::kQInt8), w_scale[0], w_zp[0]);
+  } else if (q_scheme == c10::kPerChannelAffine) {
+    auto scales = at::from_blob(
+        w_scale.data(), w_scale.size(), device(c10::kCPU).dtype(c10::kFloat));
+    auto zero_points = at::from_blob(
+        w_zp.data(), w_zp.size(), device(c10::kCPU).dtype(c10::kInt));
+
+    weight_origin = at::_empty_per_channel_affine_quantized(
+        {N, K},
+        scales.toType(c10::kDouble),
+        zero_points.toType(c10::kLong),
+        0, // The output channel axis is 0
+        device(c10::kCPU).dtype(c10::kQInt8));
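+    // w_scale / w_zp are stored as float / int32 vectors; the per-channel
+    // quantizer expects double / int64, hence the toType() conversions above.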
+  }
+
+  int8_t* weight_ptr_int8 =
+      reinterpret_cast<int8_t*>(weight_origin.data_ptr<c10::qint8>());
+
+  // packB->printPackedMatrix("packedB inside fbgemm_unpack
+  // (QLinearUnpackWeightInt8): ");
+  packB->unpack(weight_ptr_int8);
+
+  return std::tuple<at::Tensor, c10::optional<at::Tensor>>(
+      weight_origin, bias_);
+}
+#endif // USE_FBGEMM
+
+#ifdef USE_PYTORCH_QNNPACK
+std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedLinearWeightsQnnp::
+    unpack() {
+  TORCH_CHECK(
+      orig_weight.defined(),
+      "Cannot unpack weights. "
+      "Call at::globalContext()::setReleaseOriginalWeights(false) before packing or loading to enable unpacking.");
+  return std::tuple<at::Tensor, c10::optional<at::Tensor>>(orig_weight, bias_);
+}
+#endif // USE_PYTORCH_QNNPACK
+
+#ifdef USE_FBGEMM
+std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedLinearWeightFp16::
+    unpack() {
+  auto& packed_weight_ptr = w;
+
+  auto nrows = packed_weight_ptr->numRows();
+  auto ncols = packed_weight_ptr->numCols();
+
+  at::Tensor unpacked_weight =
+      at::empty({ncols, nrows}, at::kHalf, c10::MemoryFormat::Contiguous);
+  packed_weight_ptr->unpack(
+      static_cast<fbgemm::float16*>(unpacked_weight.data_ptr()),
+      fbgemm::matrix_op_t::Transpose);
+
+  return std::make_tuple(unpacked_weight.to(at::kFloat), bias_);
+}
+#endif // USE_FBGEMM
 
 namespace at {
 namespace native {
@@ -10,139 +80,68 @@
 
 class QLinearUnpackWeightInt8 final {
  public:
-#ifdef USE_FBGEMM
-  static std::tuple<at::Tensor, c10::optional<Tensor>> fbgemm_linear_unpack(
-      at::Tensor packed_weight) {
-    // Pull out the PackBMatrix instance from the owning tensor.
-    auto& pack_ptr =
-        cpp_custom_type_hack::cast<PackedLinearWeight>(packed_weight);
-    auto packB = pack_ptr.w.get();
-
-    int64_t N = static_cast<int64_t>(packB->numCols());
-    int64_t K = static_cast<int64_t>(packB->numRows());
-
-    Tensor weight_origin;
-    if (pack_ptr.q_scheme == kPerTensorAffine) {
-      weight_origin = _empty_affine_quantized(
-          {N, K},
-          at::device(kCPU).dtype(kQInt8),
-          pack_ptr.w_scale[0],
-          pack_ptr.w_zp[0]);
-    } else if (pack_ptr.q_scheme == kPerChannelAffine) {
-      auto scales = from_blob(
-          pack_ptr.w_scale.data(),
-          pack_ptr.w_scale.size(),
-          device(kCPU).dtype(kFloat));
-      auto zero_points = from_blob(
-          pack_ptr.w_zp.data(), pack_ptr.w_zp.size(), device(kCPU).dtype(kInt));
-
-      weight_origin = _empty_per_channel_affine_quantized(
-          {N, K},
-          scales.toType(kDouble),
-          zero_points.toType(kLong),
-          0, // The output channel axis is 0
-          device(kCPU).dtype(kQInt8));
-    }
-
-    int8_t* weight_ptr_int8 =
-        reinterpret_cast<int8_t*>(weight_origin.data_ptr<c10::qint8>());
-
-    // packB->printPackedMatrix("packedB inside fbgemm_unpack
-    // (QLinearUnpackWeightInt8): ");
-    packB->unpack(weight_ptr_int8);
-
-    return std::tuple<at::Tensor, c10::optional<Tensor>>(
-        weight_origin, pack_ptr.bias);
-  }
-#endif // USE_FBGEMM
-#ifdef USE_PYTORCH_QNNPACK
-  static std::tuple<at::Tensor, c10::optional<Tensor>> qnnpack_linear_unpack(
-      at::Tensor packed_weight) {
-    auto& pack_ptr =
-        cpp_custom_type_hack::cast<PackedLinearWeightsQnnp>(packed_weight);
-    TORCH_CHECK(
-        pack_ptr.orig_weight.defined(),
-        "Cannot unpack weights. "
-        "Call at::globalContext()::setReleaseOriginalWeights(false) before packing or loading to enable unpacking.");
-    return std::tuple<at::Tensor, c10::optional<Tensor>>(
-        pack_ptr.orig_weight, pack_ptr.bias);
-  }
-#endif // USE_PYTORCH_QNNPACK
   static std::tuple<at::Tensor, c10::optional<Tensor>> run(
-      at::Tensor packed_weight) {
-    auto& ctx = at::globalContext();
-
-#ifdef USE_FBGEMM
-    if (ctx.qEngine() == at::QEngine::FBGEMM) {
-      return fbgemm_linear_unpack(packed_weight);
-    }
-#endif
-#ifdef USE_PYTORCH_QNNPACK
-    if (ctx.qEngine() == at::QEngine::QNNPACK) {
-      return qnnpack_linear_unpack(packed_weight);
-    }
-#endif
-    TORCH_CHECK(
-        false,
-        "Didn't find engine for operation quantized::linear_unpack ",
-        toString(ctx.qEngine()));
+      const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight) {
+    return packed_weight->unpack();
   }
 };
 
 class QLinearUnpackWeightFp16 final {
  public:
-#ifdef USE_FBGEMM
-  static std::tuple<at::Tensor, c10::optional<Tensor>> fbgemm_linear_unpack(
-      at::Tensor packed_weight) {
-    // Pull out the PackBMatrix instance from the owning tensor.
-    auto& packed_struct =
-        cpp_custom_type_hack::cast<PackedLinearWeightFp16>(packed_weight);
-    auto& packed_weight_ptr = packed_struct.w;
-    auto& bias = packed_struct.bias;
-
-    auto nrows = packed_weight_ptr->numRows();
-    auto ncols = packed_weight_ptr->numCols();
-
-    at::Tensor unpacked_weight =
-        at::empty({ncols, nrows}, at::kHalf, MemoryFormat::Contiguous);
-    packed_weight_ptr->unpack(
-        static_cast<fbgemm::float16*>(unpacked_weight.data_ptr()),
-        fbgemm::matrix_op_t::Transpose);
-
-    return std::make_tuple(unpacked_weight.to(at::kFloat), bias);
-  }
-#endif // USE_FBGEMM
-#ifdef USE_PYTORCH_QNNPACK
-  static std::tuple<at::Tensor, c10::optional<Tensor>> qnnpack_linear_unpack(
-      at::Tensor packed_weight) {
-    TORCH_CHECK(
-        false,
-        "quantized::linear_unpack_fp16 is currently "
-        "not supported by QNNPACK");
-  }
-#endif // USE_PYTORCH_QNNPACK
   static std::tuple<at::Tensor, c10::optional<Tensor>> run(
-      at::Tensor packed_weight) {
+      const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight) {
     auto& ctx = at::globalContext();
 
-#ifdef USE_FBGEMM
-    if (ctx.qEngine() == at::QEngine::FBGEMM) {
-      return fbgemm_linear_unpack(packed_weight);
-    }
-#endif
-#ifdef USE_PYTORCH_QNNPACK
-    if (ctx.qEngine() == at::QEngine::QNNPACK) {
-      return qnnpack_linear_unpack(packed_weight);
-    }
-#endif
     TORCH_CHECK(
-        false,
-        "Didn't find engine for operation quantized::linear_unpack_fp16 ",
-        toString(ctx.qEngine()));
+        ctx.qEngine() != at::QEngine::QNNPACK,
+        "quantized::linear_unpack_fp16 is currently "
+        "not supported by QNNPACK");
+
+    return packed_weight->unpack();
+  }
+};
+
+class QLinearUnpackWeightInt8Legacy final {
+ public:
+  static std::tuple<at::Tensor, c10::optional<Tensor>> run(
+      const at::Tensor& packed_weight) {
+    TORCH_WARN_ONCE(
+        "quantized.linear_unpack(Tensor) is deprecated! Please "
+        "upgrade your model to use the newer quantized.linear_"
+        "unpack(LinearPackedParamsBase) overload");
+    return cpp_custom_type_hack::cast<
+               c10::intrusive_ptr<LinearPackedParamsBase>>(packed_weight)
+        ->unpack();
+  }
+};
+
+class QLinearUnpackWeightFp16Legacy final {
+ public:
+  static std::tuple<at::Tensor, c10::optional<Tensor>> run(
+      const at::Tensor& packed_weight) {
+    TORCH_WARN_ONCE(
+        "quantized.linear_unpack(Tensor) is deprecated! Please "
+        "upgrade your model to use the newer quantized.linear_"
+        "unpack(LinearPackedParamsBase) overload");
+    auto& ctx = at::globalContext();
+
+    TORCH_CHECK(
+        ctx.qEngine() != at::QEngine::QNNPACK,
+        "quantized::linear_unpack_fp16 is currently "
+        "not supported by QNNPACK");
+
+    return cpp_custom_type_hack::cast<
+               c10::intrusive_ptr<LinearPackedParamsBase>>(packed_weight)
+        ->unpack();
   }
 };
 
 TORCH_LIBRARY_IMPL(quantized, CPU, m) {
+  m.impl("linear_unpack.legacy", QLinearUnpackWeightInt8Legacy::run);
+  m.impl("linear_unpack_fp16.legacy", QLinearUnpackWeightFp16Legacy::run);
+}
+
+TORCH_LIBRARY_IMPL(quantized, CatchAll, m) {
   m.impl("linear_unpack", QLinearUnpackWeightInt8::run);
   m.impl("linear_unpack_fp16", QLinearUnpackWeightFp16::run);
 }
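
With this change the non-legacy quantized::linear_unpack ops take the TorchBind LinearPackedParamsBase object directly and just forward to its virtual unpack(), while the .legacy overloads keep accepting the old cpp_custom_type_hack Tensor and emit a one-time deprecation warning. A minimal Python sketch of the new round trip (assuming an FBGEMM or QNNPACK build; the shapes are illustrative only):

import torch

# Quantize a small weight matrix and prepack it; the result is now a
# TorchBind LinearPackedParamsBase object rather than an opaque Tensor.
w = torch.randn(4, 8)
qw = torch.quantize_per_tensor(w, scale=0.1, zero_point=0, dtype=torch.qint8)
packed = torch.ops.quantized.linear_prepack(qw)

# unpack() round-trips back to (weight, optional bias).
w_out, b_out = torch.ops.quantized.linear_unpack(packed)
assert w_out.shape == qw.shape
assert b_out is None  # no bias was passed to prepack

Weights packed through linear_prepack_legacy still unpack through the .legacy overloads, but now hit the TORCH_WARN_ONCE deprecation message above.
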
diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack_utils.h b/aten/src/ATen/native/quantized/cpu/qnnpack_utils.h
index 8f59937..6434bfc 100644
--- a/aten/src/ATen/native/quantized/cpu/qnnpack_utils.h
+++ b/aten/src/ATen/native/quantized/cpu/qnnpack_utils.h
@@ -5,6 +5,7 @@
 #include <qnnpack_func.h>
 
 #include <ATen/native/quantized/cpu/conv_packed_params.h>
+#include <ATen/native/quantized/cpu/packed_params.h>
 
 struct QnnpackOperatorDeleter {
   void operator()(pytorch_qnnp_operator_t op) {
@@ -22,13 +23,59 @@
 // input scale value changes then we requantize bias with the updated scale. For
 // inference we expect the graph to be static so the input scale should not
 // change across consecutive inference calls.
-struct PackedLinearWeightsQnnp {
+struct PackedLinearWeightsQnnp : public LinearPackedParamsBase {
+  PackedLinearWeightsQnnp(
+      std::unique_ptr<qnnpack::PackBMatrix> w,
+      at::Tensor orig_weight,
+      at::Tensor bias,
+      c10::optional<double> input_scale,
+      double w_scale,
+      int64_t w_zp)
+      : w(std::move(w)),
+        orig_weight(std::move(orig_weight)),
+        bias_(std::move(bias)),
+        input_scale(std::move(input_scale)),
+        w_scale(w_scale),
+        w_zp(w_zp) {}
+
   std::unique_ptr<qnnpack::PackBMatrix> w;
   at::Tensor orig_weight;
-  at::Tensor bias;
+  at::Tensor bias_;
   c10::optional<double> input_scale;
   double w_scale;
   int64_t w_zp;
+
+  at::Tensor apply(
+      at::Tensor input,
+      double output_scale,
+      int64_t output_zero_point) override;
+  at::Tensor apply_relu(
+      at::Tensor input,
+      double output_scale,
+      int64_t output_zero_point) override;
+
+  at::Tensor apply_dynamic(at::Tensor input) override;
+  at::Tensor apply_dynamic_relu(at::Tensor input) override;
+
+  std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() override;
+
+  c10::optional<at::Tensor> bias() override {
+    return bias_;
+  }
+
+  static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
+      at::Tensor weight,
+      c10::optional<at::Tensor> bias);
+
+ private:
+  template <bool ReluFused>
+  at::Tensor apply_impl(
+      at::Tensor input,
+      double output_scale,
+      int64_t output_zero_point);
+
+  template <bool ReluFused>
+  at::Tensor apply_dynamic_impl(at::Tensor input);
 };
 
 template <int kSpatialDim = 2>
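
PackedLinearWeightsQnnp now implements the shared LinearPackedParamsBase interface (apply/apply_relu, the dynamic variants, unpack, bias), and the comment above it notes that the bias is requantized whenever the input scale changes. A rough sketch of that requantization, assuming the usual convention that the integer bias is expressed at scale input_scale * weight_scale (requantize_bias below is a hypothetical helper, not a function in this codebase):

import torch

def requantize_bias(bias_fp32, input_scale, weight_scale):
    # Quantized linear kernels typically fold the bias in at scale
    # input_scale * weight_scale, so a new input scale implies a new
    # integer bias vector.
    return torch.round(bias_fp32 / (input_scale * weight_scale)).to(torch.int32)

bias = torch.tensor([0.25, -0.5])
print(requantize_bias(bias, input_scale=0.02, weight_scale=0.1))
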
diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp
index ba5f201..be8da83 100644
--- a/aten/src/ATen/native/quantized/library.cpp
+++ b/aten/src/ATen/native/quantized/library.cpp
@@ -1,8 +1,11 @@
 #include <torch/library.h>
 
 #include <ATen/native/quantized/cpu/conv_packed_params.h>
+#include <ATen/native/quantized/cpu/packed_params.h>
 #include <torch/custom_class.h>
 
+torch::jit::class_<LinearPackedParamsBase> register_linear_params();
+
 template <int kSpatialDim = 2>
 torch::jit::class_<ConvPackedParamsBase<kSpatialDim>> register_conv_params();
 
@@ -10,6 +13,7 @@
 extern template torch::jit::class_<ConvPackedParamsBase<3>> register_conv_params<3>();
 
 TORCH_LIBRARY(quantized, m) {
+  register_linear_params();
   register_conv_params<2>();
   register_conv_params<3>();
 
@@ -58,15 +62,31 @@
   m.def("group_norm(Tensor input, int num_groups, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> Tensor");
   m.def("instance_norm(Tensor input, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> Tensor");
   m.def("layer_norm(Tensor input, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> Tensor");
-  m.def("linear(Tensor X, Tensor W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y");
-  m.def("linear_relu(Tensor X, Tensor W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y");
-  m.def("linear_dynamic(Tensor X, Tensor W_prepack) -> Tensor Y");
-  m.def("linear_relu_dynamic(Tensor X, Tensor W_prepack) -> Tensor Y");
-  m.def("linear_dynamic_fp16(Tensor X, Tensor W_prepack) -> Tensor Y");
-  m.def("linear_prepack(Tensor W, Tensor? B=None) -> Tensor W_prepack");
-  m.def("linear_prepack_fp16(Tensor W, Tensor? B=None) -> Tensor W_prepack");
-  m.def("linear_unpack(Tensor W_prepack) -> (Tensor W_origin, Tensor? B_origin)");
-  m.def("linear_unpack_fp16(Tensor W_prepack) -> (Tensor W_origin, Tensor? B_origin)");
+  m.def(
+      "linear(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y");
+  m.def(
+      "linear_relu(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y");
+  m.def(
+      "linear_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> Tensor Y");
+  m.def(
+      "linear_relu_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> Tensor Y");
+  m.def(
+      "linear_dynamic_fp16(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> Tensor Y");
+  m.def(
+      "linear_prepack(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack");
+  m.def(
+      "linear_prepack_fp16(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack");
+  m.def("linear_prepack_legacy(Tensor W, Tensor? B=None) -> Tensor W_prepack");
+  m.def(
+      "linear_prepack_fp16_legacy(Tensor W, Tensor? B=None) -> Tensor W_prepack");
+  m.def(
+      "linear_unpack(__torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> (Tensor W_origin, Tensor? B_origin)");
+  m.def(
+      "linear_unpack_fp16(__torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> (Tensor W_origin, Tensor? B_origin)");
+  m.def(
+      "linear_unpack.legacy(Tensor W_prepack) -> (Tensor W_origin, Tensor? B_origin)");
+  m.def(
+      "linear_unpack_fp16.legacy(Tensor W_prepack) -> (Tensor W_origin, Tensor? B_origin)");
   m.def("mul(Tensor qa, Tensor qb, float scale, int zero_point)-> Tensor qc");
   m.def("mul_relu(Tensor qa, Tensor qb, float scale, int zero_point)-> Tensor qc");
   m.def("mul_out(Tensor qa, Tensor qb, Tensor(a!) out)-> Tensor(a!) out");
@@ -88,8 +108,15 @@
   m.def("conv2d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor");
   m.def("conv2d_relu(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor");
   m.def("conv2d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase");
-  m.def("linear(Tensor X, Tensor W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y");
-  m.def("linear_dynamic(Tensor X, Tensor W_prepack) -> Tensor Y");
-  m.def("linear_prepack(Tensor W, Tensor? B=None) -> Tensor W_prepack");
-  m.def("linear_prepack_fp16(Tensor W, Tensor? B=None) -> Tensor W_prepack");
+  m.def(
+      "linear(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y");
+  m.def(
+      "linear_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> Tensor Y");
+  m.def(
+      "linear_prepack(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack");
+  m.def(
+      "linear_prepack_fp16(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack");
+  m.def("linear_prepack_legacy(Tensor W, Tensor? B=None) -> Tensor W_prepack");
+  m.def(
+      "linear_prepack_fp16_legacy(Tensor W, Tensor? B=None) -> Tensor W_prepack");
 }
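
The schema changes above are the user-visible core of this change: linear, linear_dynamic, linear_unpack and friends now traffic in __torch__.torch.classes.quantized.LinearPackedParamsBase, while the new *_legacy entry points keep the old Tensor-based contract for previously serialized models. A short sketch of the difference as seen from Python (assuming FBGEMM is available; shapes are arbitrary):

import torch

qw = torch.quantize_per_tensor(torch.randn(8, 8), 0.1, 0, torch.qint8)

new_packed = torch.ops.quantized.linear_prepack(qw)
legacy_packed = torch.ops.quantized.linear_prepack_legacy(qw)

# The new op hands back a TorchBind object; the legacy op still returns a
# Tensor that wraps the packed weight via cpp_custom_type_hack.
assert isinstance(new_packed, torch._C.ScriptObject)
assert isinstance(legacy_packed, torch.Tensor)
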
diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py
index 6faf113..0092e83 100644
--- a/test/backward_compatibility/check_backward_compatibility.py
+++ b/test/backward_compatibility/check_backward_compatibility.py
@@ -31,7 +31,6 @@
     ('aten::index_put_', datetime.date(2020, 4, 10)),
     ('aten::quantize_per_tensor', datetime.date(2020, 4, 15)),
     ('aten::requires_grad_', datetime.date(2020, 4, 30)),
-    ('quantized::batch_norm', datetime.date(2020, 4, 20)),
     ('aten::sizes', datetime.date(2020, 4, 30)),
     ('aten::strides', datetime.date(2020, 4, 30)),
     ('aten::backward', datetime.date(2020, 4, 30)),
@@ -72,6 +71,17 @@
     ('aten::dict', datetime.date(2020, 6, 30)),
     ('aten::tensor', datetime.date(2020, 6, 30)),
     ('aten::as_tensor', datetime.date(2020, 6, 30)),
+    ('quantized::linear_unpack_fp16', datetime.date(2020, 6, 1)),
+    ('quantized::linear_unpack', datetime.date(2020, 6, 1)),
+    ('quantized::linear_prepack_fp16', datetime.date(2020, 6, 1)),
+    ('quantized::linear_prepack', datetime.date(2020, 6, 1)),
+    ('quantized::linear_dynamic_fp16', datetime.date(2020, 6, 1)),
+    ('quantized::linear_relu_dynamic', datetime.date(2020, 6, 1)),
+    ('quantized::linear_dynamic', datetime.date(2020, 6, 1)),
+    ('quantized::linear_relu', datetime.date(2020, 6, 1)),
+    ('quantized::linear', datetime.date(2020, 6, 1)),
+    ('_aten::*', datetime.date(2020, 6, 1)),
+    ('_prim::*', datetime.date(2020, 6, 1)),
 ]
 
 
@@ -94,6 +104,15 @@
     ('quantized::conv3d_unpack', datetime.date(2020, 6, 1)),
     ('quantized::conv3d', datetime.date(2020, 6, 1)),
     ('quantized::conv3d_relu', datetime.date(2020, 6, 1)),
+    ('quantized::linear_unpack_fp16', datetime.date(2020, 6, 1)),
+    ('quantized::linear_unpack', datetime.date(2020, 6, 1)),
+    ('quantized::linear_prepack_fp16', datetime.date(2020, 6, 1)),
+    ('quantized::linear_prepack', datetime.date(2020, 6, 1)),
+    ('quantized::linear_dynamic_fp16', datetime.date(2020, 6, 1)),
+    ('quantized::linear_relu_dynamic', datetime.date(2020, 6, 1)),
+    ('quantized::linear_dynamic', datetime.date(2020, 6, 1)),
+    ('quantized::linear_relu', datetime.date(2020, 6, 1)),
+    ('quantized::linear', datetime.date(2020, 6, 1)),
 ]
 
 
diff --git a/test/quantization/test_quantized_module.py b/test/quantization/test_quantized_module.py
index 89bb15f..e1539dc 100644
--- a/test/quantization/test_quantized_module.py
+++ b/test/quantization/test_quantized_module.py
@@ -112,15 +112,19 @@
 
         # Test serialization of quantized Linear Module using state_dict
         model_dict = qlinear.state_dict()
-        self.assertEqual(model_dict['_packed_params.weight'], W_q)
-        if use_bias:
-            self.assertEqual(model_dict['_packed_params.bias'], B)
         b = io.BytesIO()
         torch.save(model_dict, b)
         b.seek(0)
         loaded_dict = torch.load(b)
         for key in model_dict:
-            self.assertEqual(model_dict[key], loaded_dict[key])
+            if isinstance(model_dict[key], torch._C.ScriptObject):
+                assert isinstance(loaded_dict[key], torch._C.ScriptObject)
+                w_model, b_model = torch.ops.quantized.linear_unpack(model_dict[key])
+                w_loaded, b_loaded = torch.ops.quantized.linear_unpack(loaded_dict[key])
+                self.assertEqual(w_model, w_loaded)
+                self.assertEqual(b_model, b_loaded)
+            else:
+                self.assertEqual(model_dict[key], loaded_dict[key])
         if use_fused:
             loaded_qlinear = nnq_fused.LinearReLU(in_features, out_features)
         else:
@@ -158,9 +162,11 @@
         # self.assertEqual(qlinear.scale, loaded.scale)
         # self.assertEqual(qlinear.zero_point, loaded.zero_point)
         # <end code>
-        with self.assertRaisesRegex(RuntimeError, r'torch.save\(\) is not currently supported'):
-            b = io.BytesIO()
-            torch.save(qlinear, b)
+        #
+        # Currently disabled after TorchBind PR
+        # with self.assertRaisesRegex(RuntimeError, r'torch.save\(\) is not currently supported'):
+        #     b = io.BytesIO()
+        #     torch.save(qlinear, b)
 
         # Test JIT
         self.checkScriptable(qlinear, list(zip([X_q], [Z_ref])), check_save_load=True)
@@ -598,15 +604,19 @@
 
         # Test serialization of dynamic quantized Linear Module using state_dict
         model_dict = qlinear.state_dict()
-        self.assertEqual(model_dict['_packed_params.weight'], W_q)
-        if use_bias:
-            self.assertEqual(model_dict['_packed_params.bias'], B)
         b = io.BytesIO()
         torch.save(model_dict, b)
         b.seek(0)
         loaded_dict = torch.load(b)
         for key in model_dict:
-            self.assertEqual(model_dict[key], loaded_dict[key])
+            if isinstance(model_dict[key], torch._C.ScriptObject):
+                assert isinstance(loaded_dict[key], torch._C.ScriptObject)
+                w_model, b_model = torch.ops.quantized.linear_unpack(model_dict[key])
+                w_loaded, b_loaded = torch.ops.quantized.linear_unpack(loaded_dict[key])
+                self.assertEqual(w_model, w_loaded)
+                self.assertEqual(b_model, b_loaded)
+            else:
+                self.assertEqual(model_dict[key], loaded_dict[key])
         loaded_qlinear = nnqd.Linear(in_features, out_features)
         loaded_qlinear.load_state_dict(loaded_dict)
 
@@ -639,9 +649,9 @@
         # self.assertEqual(qlinear.weight(), loaded.weight())
         # self.assertEqual(qlinear.zero_point, loaded.zero_point)
         # <end code>
-        with self.assertRaisesRegex(RuntimeError, r'torch.save\(\) is not currently supported'):
-            b = io.BytesIO()
-            torch.save(qlinear, b)
+        # with self.assertRaisesRegex(RuntimeError, r'torch.save\(\) is not currently supported'):
+        #     b = io.BytesIO()
+        #     torch.save(qlinear, b)
 
         # Test JIT
         self.checkScriptable(qlinear, list(zip([X], [Z_ref])), check_save_load=True)
diff --git a/test/test_docs_coverage.py b/test/test_docs_coverage.py
index 88b4982..87efb4a 100644
--- a/test/test_docs_coverage.py
+++ b/test/test_docs_coverage.py
@@ -98,7 +98,7 @@
         in_rst = self.parse_rst('tensors.rst', r2)
         whitelist = {
             'names', 'unflatten', 'align_as', 'rename_', 'refine_names', 'align_to',
-            'has_names', 'rename',
+            'has_names', 'rename'
         }
         classes = [torch.FloatTensor, torch.LongTensor, torch.ByteTensor]
         has_docstring = set(x for c in classes for x in dir(c) if not x.startswith('_') and getattr(c, x).__doc__)
diff --git a/test/test_jit.py b/test/test_jit.py
index c1d7440..0496603 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -17523,8 +17523,7 @@
                 qweight = torch._empty_affine_quantized(
                     [out_features, in_features], scale=1, zero_point=0,
                     dtype=torch.qint8)
-                self.register_buffer('_packed_weight',
-                                     torch.ops.quantized.linear_prepack(qweight))
+                self._packed_weight = torch.ops.quantized.linear_prepack(qweight)
 
             @torch.jit.export
             def __getstate__(self):
@@ -17535,8 +17534,7 @@
 
             @torch.jit.export
             def __setstate__(self, state):
-                self._packed_weight.set_(
-                    torch.ops.quantized.linear_prepack(state[0]))
+                self._packed_weight = torch.ops.quantized.linear_prepack(state[0])
                 self.training = state[1]
 
             @property
diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp
index 54235d2..ed75845 100644
--- a/torch/csrc/jit/frontend/ir_emitter.cpp
+++ b/torch/csrc/jit/frontend/ir_emitter.cpp
@@ -2207,10 +2207,41 @@
         if (stmt.type().present()) {
           type = typeParser_.parseTypeFromExpr(stmt.type().get());
         }
+        auto rhs_sugared_val = emitSugaredExpr(rhs, 1, type);
+        // START BC HACK
+        //
+        // For old serialized quantized RNN modules, switch
+        // quantized::linear_prepack to quantized::linear_prepack_legacy. We
+        // changed linear_prepack to return a TorchBind class instead of a
+        // cpp_custom_type_hack tensor, but the old serialized models are
+        // tightly coupled with the type_hack version. If we still create a
+        // Tensor here, the quantized_lstm.legacy overload can kick in within
+        // forward_impl(), and the module will still run correctly.
+        if (method.qualname() ==
+            "__torch__.torch.nn.quantized.dynamic.modules.rnn.PackedParameter.__setstate__") {
+          if (auto sv =
+                  std::dynamic_pointer_cast<SimpleValue>(rhs_sugared_val)) {
+            Node* rhs_node = sv->getValue()->node();
+            if (rhs_node->kind() ==
+                Symbol::fromQualString("quantized::linear_prepack")) {
+              std::vector<NamedValue> inputs;
+              for (Value* i : rhs_node->inputs()) {
+                inputs.emplace_back(i);
+              }
+              Value* new_val = rhs_node->owningGraph()->insert(
+                  Symbol::fromQualString("quantized::linear_prepack_legacy"),
+                  inputs,
+                  {},
+                  rhs_node->sourceRange());
+              rhs_sugared_val = std::make_shared<SimpleValue>(new_val);
+            }
+          }
+        }
+        // END BC HACK
         environment_stack->setSugaredVar(
             v.range(),
             v.name().name(),
-            emitSugaredExpr(rhs, 1, type),
+            std::move(rhs_sugared_val),
             /*annotated_type=*/type);
       } break;
       case TK_TUPLE_LITERAL:
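
The BC hack above only fires while emitting PackedParameter.__setstate__ from old archives: the right-hand-side quantized::linear_prepack node is re-inserted as quantized::linear_prepack_legacy, so the attribute keeps its old Tensor type and the legacy RNN overloads still match. A hedged Python sketch that mirrors the rewrite condition (the method_qualname/node_kind strings stand in for the Method and Node the emitter sees; this is not a real Python API):

# Mirrors the C++ check in ir_emitter.cpp, for illustration only.
TARGET_METHOD = ("__torch__.torch.nn.quantized.dynamic.modules.rnn."
                 "PackedParameter.__setstate__")

def maybe_legacy_symbol(method_qualname, node_kind):
    if (method_qualname == TARGET_METHOD
            and node_kind == "quantized::linear_prepack"):
        return "quantized::linear_prepack_legacy"
    return node_kind

assert (maybe_legacy_symbol(TARGET_METHOD, "quantized::linear_prepack")
        == "quantized::linear_prepack_legacy")
assert (maybe_legacy_symbol("some.other.method", "quantized::linear_prepack")
        == "quantized::linear_prepack")
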
diff --git a/torch/csrc/jit/passes/lower_graph.cpp b/torch/csrc/jit/passes/lower_graph.cpp
index c64238d..c19aee6 100644
--- a/torch/csrc/jit/passes/lower_graph.cpp
+++ b/torch/csrc/jit/passes/lower_graph.cpp
@@ -131,7 +131,10 @@
                "__torch__.torch.classes.quantized.Conv2dPackedParamsBase")) ||
               (type ==
                getCustomClass(
-                   "__torch__.torch.classes.quantized.Conv3dPackedParamsBase")),
+                   "__torch__.torch.classes.quantized.Conv3dPackedParamsBase")) ||
+              (type ==
+               getCustomClass(
+                   "__torch__.torch.classes.quantized.LinearPackedParamsBase")),
           "Unknown type ",
           type->python_str(),
           " encountered in graph lowering. This type is not supported in ONNX export.");
diff --git a/torch/csrc/jit/passes/onnx/constant_fold.h b/torch/csrc/jit/passes/onnx/constant_fold.h
index f9f3d20..0a76040 100644
--- a/torch/csrc/jit/passes/onnx/constant_fold.h
+++ b/torch/csrc/jit/passes/onnx/constant_fold.h
@@ -15,4 +15,5 @@
     int opset_version);
 
 } // namespace jit
+
 } // namespace torch
diff --git a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp
index ad75f96..e2a4b48 100644
--- a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp
+++ b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp
@@ -1,4 +1,5 @@
 #include <torch/csrc/jit/passes/onnx/unpack_quantized_weights.h>
+#include <ATen/native/quantized/cpu/packed_params.h>
 #include <torch/csrc/jit/ir/constants.h>
 #include <torch/csrc/jit/ir/irparser.h>
 #include <torch/csrc/jit/ir/subgraph_matcher.h>
@@ -182,7 +183,7 @@
     c10::optional<int64_t> groups;
 
     if (itr->second.isTuple()) {
-      // Pre-unpacked weights. Comes from Conv weights which are
+      // Pre-unpacked weights. These come from Conv/Linear weights, which are
       // stored as bound C++ classes.
       auto ser_tup = itr->second.toTuple();
       unpacked_weight = ser_tup->elements()[0].toTensor();
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 9436e93..6bfe4ab 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -559,7 +559,10 @@
                       "__torch__.torch.classes.quantized.Conv2dPackedParamsBase") ||
               i->type() ==
                   getCustomClass(
-                      "__torch__.torch.classes.quantized.Conv3dPackedParamsBase")) {
+                      "__torch__.torch.classes.quantized.Conv3dPackedParamsBase") ||
+              i->type() ==
+                  getCustomClass(
+                      "__torch__.torch.classes.quantized.LinearPackedParamsBase")) {
             // Dummy CompleteTensorType to appease ONNX validator.
             i->setType(TensorType::create(
                 at::kQInt8,
diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp
index 462c720..2944960 100644
--- a/torch/csrc/jit/python/script_init.cpp
+++ b/torch/csrc/jit/python/script_init.cpp
@@ -443,14 +443,13 @@
     // Leave packed param types alone. This is needed for downstream passes
     // (like alias analysis) to work properly. This will be unpacked later
     // in unpackQuantizedWeights.
-    if ((v->type() ==
-         getCustomClass(
-             "__torch__.torch.classes.quantized.Conv2dPackedParamsBase")) ||
-        (v->type() ==
-         getCustomClass(
-             "__torch__.torch.classes.quantized.Conv3dPackedParamsBase"))) {
-      s_iter++;
-      continue;
+    if (auto named_type = v->type()->cast<c10::NamedType>()) {
+      if (auto qualname = named_type->name()) {
+        if (getCustomClass(qualname->qualifiedName())) {
+          s_iter++;
+          continue;
+        }
+      }
     }
     if (v->type()->kind() == TupleType::Kind) {
       AT_ASSERT(v->node()->kind() == prim::Param);
diff --git a/torch/csrc/jit/serialization/import_source.cpp b/torch/csrc/jit/serialization/import_source.cpp
index 3a157ed..ecb6acf 100644
--- a/torch/csrc/jit/serialization/import_source.cpp
+++ b/torch/csrc/jit/serialization/import_source.cpp
@@ -262,35 +262,58 @@
     }
   }
 
-  c10::optional<Assign> qconvAttributeAssignmentSpecialHandlingHack(
+  c10::optional<Assign> attributeAssignmentSpecialHandlingHack(
       const QualifiedName& qualified_classname,
       const Assign& assign) {
-    static std::regex mangle_re("\\.___torch_mangle_\\d+");
-    auto replaced_string =
-        std::regex_replace(qualified_classname.qualifiedName(), mangle_re, "");
-    auto is_conv2d = [](const std::string& type) {
-      return type == "__torch__.torch.nn.quantized.modules.conv.Conv2d" ||
-          type ==
-          "__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU2d";
+    struct AttrTypeReplacementDescr {
+      std::string attr_name;
+      std::string expected_type;
+      std::string replacement_type;
     };
 
-    auto is_conv3d = [](const std::string& type) {
-      return type == "__torch__.torch.nn.quantized.modules.conv.Conv3d" ||
-          type ==
-          "__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU3d";
-    };
-    if (is_conv2d(replaced_string) || is_conv3d(replaced_string)) {
+    // module demangled qualname -> ReplacementDescr
+    static std::unordered_map<std::string, AttrTypeReplacementDescr> replacements{
+        {"__torch__.torch.nn.quantized.modules.linear.LinearPackedParams",
+         {"_packed_params",
+          "Tensor",
+          "__torch__.torch.classes.quantized.LinearPackedParamsBase"}},
+        {"__torch__.torch.nn.quantized.modules.linear.Linear",
+         {"_packed_params",
+          "Tensor",
+          "__torch__.torch.classes.quantized.LinearPackedParamsBase"}},
+        {"__torch__.torch.nn.quantized.modules.conv.Conv2d",
+         {"_packed_params",
+          "Tensor",
+          "__torch__.torch.classes.quantized.Conv2dPackedParamsBase"}},
+        {"__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU2d",
+         {"_packed_params",
+          "Tensor",
+          "__torch__.torch.classes.quantized.Conv2dPackedParamsBase"}},
+        {"__torch__.torch.nn.quantized.modules.conv.Conv3d",
+         {"_packed_params",
+          "Tensor",
+          "__torch__.torch.classes.quantized.Conv3dPackedParamsBase"}},
+        {"__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU3d",
+         {"_packed_params",
+          "Tensor",
+          "__torch__.torch.classes.quantized.Conv3dPackedParamsBase"}}};
+    static std::regex mangle_re("\\.___torch_mangle_\\d+");
+    auto demangled_classname =
+        std::regex_replace(qualified_classname.qualifiedName(), mangle_re, "");
+    if (replacements.count(demangled_classname)) {
       auto lhs = Var(assign.lhs());
       if (!assign.type().present() || assign.type().get().kind() != TK_VAR) {
         return c10::nullopt;
       }
       auto type = Var(assign.type().get());
-      if (lhs.name().name() == "_packed_params" &&
-          type.name().name() == "Tensor") {
-        std::string packed_params_typename = is_conv2d(replaced_string)
-            ? "__torch__.torch.classes.quantized.Conv2dPackedParamsBase"
-            : "__torch__.torch.classes.quantized.Conv3dPackedParamsBase";
-        Parser p(std::make_shared<Source>(std::move(packed_params_typename)));
+
+      auto& attr_name = replacements.at(demangled_classname).attr_name;
+      auto& expected_type = replacements.at(demangled_classname).expected_type;
+      auto& replacement_type =
+          replacements.at(demangled_classname).replacement_type;
+      if (lhs.name().name() == attr_name &&
+          type.name().name() == expected_type) {
+        Parser p(std::make_shared<Source>(replacement_type));
         auto typename_expr = p.parseExp();
         auto maybe_typename =
             Maybe<Expr>::create(typename_expr.range(), typename_expr);
@@ -358,7 +381,7 @@
                 // This is to initialize the annotations dict, just ignore.
                 continue;
               } else {
-                if (auto fixed_up = qconvAttributeAssignmentSpecialHandlingHack(
+                if (auto fixed_up = attributeAssignmentSpecialHandlingHack(
                         qualified_classname, assign)) {
                   attributes.push_back(std::move(*fixed_up));
                 } else if (assign.rhs().present()) {
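
attributeAssignmentSpecialHandlingHack now keys its replacements on the demangled class name: the ___torch_mangle_N segment is stripped with a regex before the table lookup, and the matching attribute's declared Tensor type is swapped for the corresponding packed-params custom class. A small Python sketch of that normalization (the table below is abbreviated to one entry for illustration):

import re

MANGLE_RE = re.compile(r"\.___torch_mangle_\d+")

REPLACEMENTS = {
    "__torch__.torch.nn.quantized.modules.linear.Linear":
        ("_packed_params", "Tensor",
         "__torch__.torch.classes.quantized.LinearPackedParamsBase"),
}

def replacement_for(qualified_classname):
    # Strip the mangle suffix so mangled and unmangled names hit the same entry.
    demangled = MANGLE_RE.sub("", qualified_classname)
    return REPLACEMENTS.get(demangled)

print(replacement_for(
    "__torch__.torch.nn.quantized.modules.linear.___torch_mangle_3.Linear"))
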
diff --git a/torch/jit/quantized.py b/torch/jit/quantized.py
index d222dec..fdc2132 100644
--- a/torch/jit/quantized.py
+++ b/torch/jit/quantized.py
@@ -292,11 +292,12 @@
                         weight_ih, weight_hh, bias_ih, bias_hh)
                 else:
                     packed_ih = torch.ops.quantized.linear_prepack_fp16(
-                        weight_ih.float())
+                        weight_ih.float(), bias_ih)
                     packed_hh = torch.ops.quantized.linear_prepack_fp16(
-                        weight_hh.float())
+                        weight_hh.float(), bias_hh)
 
-                    cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(packed_ih, packed_hh, bias_ih, bias_hh)
+                    cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
+                        packed_ih, packed_hh)
 
                 setattr(self, 'cell_params_{}_{}'.format(layer, suffix), cell_params)
                 self.all_weights.append(cell_params)
diff --git a/torch/nn/quantized/dynamic/modules/rnn.py b/torch/nn/quantized/dynamic/modules/rnn.py
index a39b54c..2dd40d9 100644
--- a/torch/nn/quantized/dynamic/modules/rnn.py
+++ b/torch/nn/quantized/dynamic/modules/rnn.py
@@ -84,10 +84,10 @@
                     # bias vector is needed in standard definition.
                     b_hh = torch.Tensor(gate_size).float()
 
-                    packed_ih = torch.ops.quantized.linear_prepack_fp16(w_ih)
-                    packed_hh = torch.ops.quantized.linear_prepack_fp16(w_hh)
+                    packed_ih = torch.ops.quantized.linear_prepack_fp16(w_ih, b_ih)
+                    packed_hh = torch.ops.quantized.linear_prepack_fp16(w_hh, b_hh)
                     cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
-                        packed_ih, packed_hh, b_ih, b_hh)
+                        packed_ih, packed_hh)
 
                 _all_weight_values.append(PackedParameter(cell_params))
         self._all_weight_values = torch.nn.ModuleList(_all_weight_values)
@@ -247,12 +247,12 @@
                         packed_ih, packed_hh, bias_ih, bias_hh)
                 else:
                     packed_ih = torch.ops.quantized.linear_prepack_fp16(
-                        weight_ih.float())
+                        weight_ih.float(), bias_ih)
                     packed_hh = torch.ops.quantized.linear_prepack_fp16(
-                        weight_hh.float())
+                        weight_hh.float(), bias_hh)
 
                     cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
-                        packed_ih, packed_hh, bias_ih, bias_hh)
+                        packed_ih, packed_hh)
 
                 _all_weight_values.append(PackedParameter(cell_params))
         qRNNBase._all_weight_values = torch.nn.ModuleList(_all_weight_values)
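
In both quantized RNN paths the fp16 bias is now packed together with the weight, so make_quantized_cell_params_fp16 only needs the two packed objects. A condensed sketch of the new call pattern (assuming an FBGEMM build; the weights and biases below are dummies standing in for the float RNN parameters being converted):

import torch

hidden_size, input_size = 4, 8
gate_size = 4 * hidden_size  # LSTM packs the four gates together

w_ih, b_ih = torch.randn(gate_size, input_size), torch.randn(gate_size)
w_hh, b_hh = torch.randn(gate_size, hidden_size), torch.randn(gate_size)

# Bias now rides along inside the packed params.
packed_ih = torch.ops.quantized.linear_prepack_fp16(w_ih.float(), b_ih)
packed_hh = torch.ops.quantized.linear_prepack_fp16(w_hh.float(), b_hh)
cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
    packed_ih, packed_hh)
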
diff --git a/torch/nn/quantized/modules/linear.py b/torch/nn/quantized/modules/linear.py
index a64de54..924a8ef 100644
--- a/torch/nn/quantized/modules/linear.py
+++ b/torch/nn/quantized/modules/linear.py
@@ -8,7 +8,7 @@
 from torch.nn.quantized.modules.utils import _quantize_weight
 
 class LinearPackedParams(torch.nn.Module):
-    _version = 2
+    _version = 3
 
     def __init__(self, dtype=torch.qint8):
         super(LinearPackedParams, self).__init__()
@@ -42,19 +42,29 @@
     def forward(self, x):
         return x
 
+    # Version 1
+    #   self
+    #   |--- weight : Tensor
+    #   |--- bias : Tensor
+    #
+    # Version 2
+    #   self
+    #   |--- weight : Tensor
+    #   |--- bias : Tensor
+    #   |--- dtype : torch.dtype
+    #
+    # Version 3
+    #   self
+    #   |--- _packed_params : (Tensor, Tensor) representing (weight, bias)
+    #                         of LinearPackedParams
+    #   |--- dtype : torch.dtype
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         super(LinearPackedParams, self)._save_to_state_dict(destination, prefix, keep_vars)
-        (w, b) = self._weight_bias()
-        destination[prefix + 'weight'] = w
-        destination[prefix + 'bias'] = b
         destination[prefix + 'dtype'] = self.dtype
+        destination[prefix + '_packed_params'] = self._weight_bias()
 
     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                               missing_keys, unexpected_keys, error_msgs):
-        self.set_weight_bias(state_dict[prefix + 'weight'], state_dict[prefix + 'bias'])
-        state_dict.pop(prefix + 'weight')
-        state_dict.pop(prefix + 'bias')
-
         version = local_metadata.get('version', None)
         if version is None or version < 2:
             self.dtype = torch.qint8
@@ -62,6 +72,16 @@
             self.dtype = state_dict[prefix + 'dtype']
             state_dict.pop(prefix + 'dtype')
 
+        if version is None or version < 3:
+            self.set_weight_bias(state_dict[prefix + 'weight'], state_dict[prefix + 'bias'])
+            state_dict.pop(prefix + 'weight')
+            state_dict.pop(prefix + 'bias')
+
+        if version == 3:
+            weight, bias = state_dict[prefix + '_packed_params']
+            state_dict.pop(prefix + '_packed_params')
+            self.set_weight_bias(weight, bias)
+
         super(LinearPackedParams, self)._load_from_state_dict(state_dict, prefix, local_metadata, False,
                                                               missing_keys, unexpected_keys, error_msgs)
 
@@ -110,7 +130,7 @@
         >>> print(output.size())
         torch.Size([128, 30])
     """
-    _version = 2
+    _version = 3
     _FLOAT_MODULE = nn.Linear
 
     def __init__(self, in_features, out_features, bias_=True, dtype=torch.qint8):
@@ -185,6 +205,30 @@
     # regular QTensor form for serialization. Packed weights should not live
     # outside the process in which they were created, rather they should be derived
     # from the QTensor weight.
+    #
+    # Version 1
+    #   self
+    #   |--- scale : float
+    #   |--- zero_point : int
+    #   |--- weight : Tensor
+    #   |--- bias : Tensor
+    #
+    # Version 2
+    #   self
+    #   |--- scale : float
+    #   |--- zero_point : int
+    #   |--- _packed_params : Module
+    #        |--- weight : Tensor
+    #        |--- bias : Tensor
+    #
+    # Version 3
+    #   self
+    #   |--- scale : float
+    #   |--- zero_point : int
+    #   |--- _packed_params : Module
+    #        |--- _packed_params : (Tensor, Tensor) representing weight, bias
+    #                              of LinearPackedParams C++ struct
+    #
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         super(Linear, self)._save_to_state_dict(destination, prefix, keep_vars)
         destination[prefix + 'scale'] = torch.tensor(self.scale)