[quant] Implement PTQ for APoT FakeQuant (#81040)

### Summary:
This PR implements PTQ for APoT FakeQuant. It runs models (Resnet-18 pre-trained model, ImageNet dataset) to compare accuracy metrics for different qconfig settings of uniform vs. APoT quantized activation and weight.

According to the collected accuracy stats, model #2 (uniform activation and APoT weight) appears to have a slight improvement in accuracy compared to model #1 (uniform activation and uniform weight) for 8-bit and significant improvement for 4-bit (see "Accuracy Stats" section below).

### Test Plan:
Run models with: `python test/quantization/core/experimental/fx_graph_mode_apot.py`

### Accuracy Stats:
8-bit (Uniform int8, APoT b = 8 k = 2)

**Model #1:** Uniform activation, uniform weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 64.43% (Top-1), 85.62% (Top-5)

**Model #2:** Uniform activation, APoT weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 64.51% (Top-1), 85.78% (Top-5)

**Model #3:** APoT activation, APoT weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 64.32% (Top-1), 85.78% (Top-5)

4-bit (Uniform int4, APoT b = 4 k = 2)

**Model #1:** Uniform activation, uniform weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 45.63% (Top-1), 71.96% (Top-5)

**Model #2:** Uniform activation, APoT weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 64.24% (Top-1), 85.56% (Top-5)

**Model #3:** APoT activation, APoT weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 45.40% (Top-1), 76.21% (Top-5)

**Full Precision model (FX Graph Mode quantized)**
Evaluation accuracy on test dataset: 69.76% (Top-1), 89.08% (Top-5)

**Eager mode quantized model**
Evaluation accuracy on test dataset: 69.49% (Top-1), 88.90% (Top-5)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81040
Approved by: https://github.com/jerryzh168
diff --git a/mypy.ini b/mypy.ini
index 248f52f..eb6b502 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -73,6 +73,9 @@
 [mypy-torch.ao.quantization.experimental.fake_quantize_function]
 ignore_missing_imports = True
 
+[mypy-torch.ao.quantization.experimental.fake_quantize]
+ignore_missing_imports = True
+
 #
 # Files with various errors. Mostly real errors, possibly some false
 # positives as well.
diff --git a/test/quantization/core/experimental/data/resnet18_pretrained_float.pth b/test/quantization/core/experimental/data/resnet18_pretrained_float.pth
new file mode 100644
index 0000000..c049f92
--- /dev/null
+++ b/test/quantization/core/experimental/data/resnet18_pretrained_float.pth
Binary files differ
diff --git a/test/quantization/core/experimental/fx_graph_mode_apot.py b/test/quantization/core/experimental/fx_graph_mode_apot.py
new file mode 100644
index 0000000..2ef7168
--- /dev/null
+++ b/test/quantization/core/experimental/fx_graph_mode_apot.py
@@ -0,0 +1,257 @@
+import torch
+import torch.nn as nn
+import torchvision
+import torchvision.transforms.transforms as transforms
+import os
+import torch.quantization
+
+# Setup warnings
+import warnings
+warnings.filterwarnings(
+    action='ignore',
+    category=DeprecationWarning,
+    module=r'.*'
+)
+warnings.filterwarnings(
+    action='default',
+    module=r'torch.quantization'
+)
+
+"""
+Define helper functions
+"""
+
+# Specify random seed for repeatable results
+_ = torch.manual_seed(191009)
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self, name, fmt=':f'):
+        self.name = name
+        self.fmt = fmt
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+    def __str__(self):
+        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+        return fmtstr.format(**self.__dict__)
+
+
+def accuracy(output, target, topk=(1,)):
+    """Computes the accuracy over the k top predictions for the specified values of k"""
+    with torch.no_grad():
+        maxk = max(topk)
+        batch_size = target.size(0)
+
+        _, pred = output.topk(maxk, 1, True, True)
+        pred = pred.t()
+        correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+        res = []
+        for k in topk:
+            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
+            res.append(correct_k.mul_(100.0 / batch_size))
+        return res
+
+
+def evaluate(model, criterion, data_loader):
+    model.eval()
+    top1 = AverageMeter('Acc@1', ':6.2f')
+    top5 = AverageMeter('Acc@5', ':6.2f')
+    cnt = 0
+    with torch.no_grad():
+        for image, target in data_loader:
+            output = model(image)
+            loss = criterion(output, target)
+            cnt += 1
+            acc1, acc5 = accuracy(output, target, topk=(1, 5))
+            top1.update(acc1[0], image.size(0))
+            top5.update(acc5[0], image.size(0))
+    print('')
+
+    return top1, top5
+
+def load_model(model_file):
+    model = resnet18(pretrained=False)
+    state_dict = torch.load(model_file)
+    model.load_state_dict(state_dict)
+    model.to("cpu")
+    return model
+
+def print_size_of_model(model):
+    if isinstance(model, torch.jit.RecursiveScriptModule):
+        torch.jit.save(model, "temp.p")
+    else:
+        torch.jit.save(torch.jit.script(model), "temp.p")
+    print("Size (MB):", os.path.getsize("temp.p") / 1e6)
+    os.remove("temp.p")
+
+def prepare_data_loaders(data_path):
+
+    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                     std=[0.229, 0.224, 0.225])
+    dataset = torchvision.datasets.ImageNet(data_path,
+                                            split="train",
+                                            transform=transforms.Compose([transforms.RandomResizedCrop(224),
+                                                                          transforms.RandomHorizontalFlip(),
+                                                                          transforms.ToTensor(),
+                                                                          normalize]))
+    dataset_test = torchvision.datasets.ImageNet(data_path,
+                                                 split="val",
+                                                 transform=transforms.Compose([transforms.Resize(256),
+                                                                               transforms.CenterCrop(224),
+                                                                               transforms.ToTensor(),
+                                                                               normalize]))
+
+    train_sampler = torch.utils.data.RandomSampler(dataset)
+    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+
+    data_loader = torch.utils.data.DataLoader(
+        dataset, batch_size=train_batch_size,
+        sampler=train_sampler)
+
+    data_loader_test = torch.utils.data.DataLoader(
+        dataset_test, batch_size=eval_batch_size,
+        sampler=test_sampler)
+
+    return data_loader, data_loader_test
+
+data_path = '~/my_imagenet/'
+saved_model_dir = '/data/home/amandaliu/cluster/pytorch/test/quantization/core/experimental/data/'
+float_model_file = 'resnet18_pretrained_float.pth'
+
+train_batch_size = 30
+eval_batch_size = 50
+
+data_loader, data_loader_test = prepare_data_loaders(data_path)
+criterion = nn.CrossEntropyLoss()
+float_model = load_model(saved_model_dir + float_model_file).to("cpu")
+float_model.eval()
+
+# deepcopy the model since we need to keep the original model around
+import copy
+model_to_quantize = copy.deepcopy(float_model)
+
+model_to_quantize.eval()
+
+"""
+Prepare models
+"""
+
+# Note that this is temporary, we'll expose these functions to torch.quantization after official releasee
+from torch.quantization.quantize_fx import prepare_fx, convert_fx
+
+def calibrate(model, data_loader):
+    model.eval()
+    with torch.no_grad():
+        for image, target in data_loader:
+            model(image)
+
+from torch.ao.quantization.experimental.qconfig import (
+    uniform_qconfig_8bit,
+    apot_weights_qconfig_8bit,
+    apot_qconfig_8bit,
+    uniform_qconfig_4bit,
+    apot_weights_qconfig_4bit,
+    apot_qconfig_4bit
+)
+
+"""
+Prepare full precision model
+"""
+full_precision_model = float_model
+
+top1, top5 = evaluate(full_precision_model, criterion, data_loader_test)
+print("Model #0 Evaluation accuracy on test dataset: %2.2f, %2.2f" % (top1.avg, top5.avg))
+
+"""
+Prepare model PTQ for specified qconfig for torch.nn.Linear
+"""
+def prepare_ptq_linear(qconfig):
+    qconfig_dict = {"object_type": [(torch.nn.Linear, qconfig)]}
+    prepared_model = prepare_fx(copy.deepcopy(float_model), qconfig_dict)  # fuse modules and insert observers
+    calibrate(prepared_model, data_loader_test)  # run calibration on sample data
+    return prepared_model
+
+"""
+Prepare model with uniform activation, uniform weight
+b=8, k=2
+"""
+
+prepared_model = prepare_ptq_linear(uniform_qconfig_8bit)
+quantized_model = convert_fx(prepared_model)  # convert the calibrated model to a quantized model
+
+top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
+print("Model #1 Evaluation accuracy on test dataset (b=8, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
+
+"""
+Prepare model with uniform activation, uniform weight
+b=4, k=2
+"""
+
+prepared_model = prepare_ptq_linear(uniform_qconfig_4bit)
+quantized_model = convert_fx(prepared_model)  # convert the calibrated model to a quantized model
+
+top1, top5 = evaluate(quantized_model1, criterion, data_loader_test)
+print("Model #1 Evaluation accuracy on test dataset (b=4, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
+
+"""
+Prepare model with uniform activation, APoT weight
+(b=8, k=2)
+"""
+
+prepared_model = prepare_ptq_linear(apot_weights_qconfig_8bit)
+
+top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
+print("Model #2 Evaluation accuracy on test dataset (b=8, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
+
+"""
+Prepare model with uniform activation, APoT weight
+(b=4, k=2)
+"""
+
+prepared_model = prepare_ptq_linear(apot_weights_qconfig_4bit)
+
+top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
+print("Model #2 Evaluation accuracy on test dataset (b=4, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
+
+
+"""
+Prepare model with APoT activation and weight
+(b=8, k=2)
+"""
+
+prepared_model = prepare_ptq_linear(apot_qconfig_8bit)
+
+top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
+print("Model #3 Evaluation accuracy on test dataset (b=8, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
+
+"""
+Prepare model with APoT activation and weight
+(b=4, k=2)
+"""
+
+prepared_model = prepare_ptq_linear(apot_qconfig_4bit)
+
+top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
+print("Model #3 Evaluation accuracy on test dataset (b=4, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
+
+"""
+Prepare eager mode quantized model
+"""
+
+from torchvision.models.quantization.resnet import resnet18
+eager_quantized_model = resnet18(pretrained=True, quantize=True).eval()
+top1, top5 = evaluate(eager_quantized_model, criterion, data_loader_test)
+print("Eager mode quantized model evaluation accuracy on test dataset: %2.2f, %2.2f" % (top1.avg, top5.avg))
diff --git a/test/quantization/core/experimental/test_fake_quantize.py b/test/quantization/core/experimental/test_fake_quantize.py
index fab526c..4e9464a 100644
--- a/test/quantization/core/experimental/test_fake_quantize.py
+++ b/test/quantization/core/experimental/test_fake_quantize.py
@@ -35,7 +35,7 @@
 
     r""" Tests fake quantize forward() method
          by comparing result with expected
-         float_to_reduced_precision mapping of input tensor.
+         quant_dequant_APoT mapping of input tensor.
          Uses input tensor with random values from 0 -> 1000
          and APoT observer with hard-coded values b=4, k=2
     """
diff --git a/torch/ao/quantization/experimental/apot_utils.py b/torch/ao/quantization/experimental/apot_utils.py
index 9804cd8..ad7a7be 100644
--- a/torch/ao/quantization/experimental/apot_utils.py
+++ b/torch/ao/quantization/experimental/apot_utils.py
@@ -33,7 +33,7 @@
     reduced precision floating point value
     based on quantization levels
 """
-def float_to_reduced_precision(x, levels, indices):
+def quant_dequant_util(x, levels, indices):
     levels_lst = list(levels)
     indices_lst = list(indices)
 
diff --git a/torch/ao/quantization/experimental/fake_quantize.py b/torch/ao/quantization/experimental/fake_quantize.py
index c229859..7541106 100644
--- a/torch/ao/quantization/experimental/fake_quantize.py
+++ b/torch/ao/quantization/experimental/fake_quantize.py
@@ -10,18 +10,23 @@
     quantization_levels: Tensor
     level_indices: Tensor
 
-    def __init__(self, **observer_kwargs):
+    def __init__(self, observer=APoTObserver, **observer_kwargs):
         super().__init__()
-        self.activation_post_process = APoTObserver(**observer_kwargs)
+        self.activation_post_process = observer(**observer_kwargs)
+        self.dtype = self.activation_post_process.dtype
 
-    def calculate_qparams(self, signed: bool):  # type: ignore[override]
+    def calculate_qparams(self, signed=False):  # type: ignore[override]
         return self.activation_post_process.calculate_qparams(signed=signed)
 
-    def forward(self, X: torch.Tensor, signed: bool):  # type: ignore[override]
+    def forward(self, X: torch.Tensor):  # type: ignore[override]
         if self.observer_enabled[0] == 1:
             self.activation_post_process.forward(X)
-            self.alpha, self.gamma, self.quantization_levels, self.level_indices = \
-                self.activation_post_process.calculate_qparams(signed)
+            result = self.activation_post_process.calculate_qparams(signed=False)
+            self.alpha = result[0]
+            self.gamma = result[1]
+            self.quantization_levels = result[2]
+            self.level_indices = result[3]
+
         if self.fake_quant_enabled[0] == 1:
             assert (self.alpha is not None
                     and self.gamma is not None
diff --git a/torch/ao/quantization/experimental/observer.py b/torch/ao/quantization/experimental/observer.py
index 77cb594..244975b 100644
--- a/torch/ao/quantization/experimental/observer.py
+++ b/torch/ao/quantization/experimental/observer.py
@@ -23,7 +23,7 @@
         self,
         b,
         k,
-            dtype=torch.int32) -> None:
+            dtype=torch.quint8) -> None:
         super().__init__(dtype)
         self.b = b
         self.k = k
@@ -47,7 +47,7 @@
         quantization_levels: non-uniform quantization levels (fp representation)
         level_indices: int representation of quantization_levels indices
     """
-    def _calculate_qparams(self, signed, min_val=None, max_val=None):
+    def _calculate_qparams(self, signed: bool, min_val=None, max_val=None):
         if min_val is not None:
             self.min_val = min_val
         if max_val is not None:
diff --git a/torch/ao/quantization/experimental/qconfig.py b/torch/ao/quantization/experimental/qconfig.py
new file mode 100644
index 0000000..f9397d1
--- /dev/null
+++ b/torch/ao/quantization/experimental/qconfig.py
@@ -0,0 +1,46 @@
+import torch
+from torch.ao.quantization.qconfig import QConfig
+from torch.ao.quantization import MinMaxObserver
+from torch.ao.quantization.fake_quantize import FakeQuantize
+from torch.ao.quantization.experimental.fake_quantize import APoTFakeQuantize
+
+"""
+Default symmetric fake_quant for activations.
+"""
+default_symmetric_fake_quant = FakeQuantize.with_args(observer=MinMaxObserver,
+                                                      qscheme=torch.per_tensor_symmetric,
+                                                      dtype=torch.quint8)
+
+"""
+Default symmetric fake_quant for weights.
+"""
+default_weight_symmetric_fake_quant = FakeQuantize.with_args(observer=MinMaxObserver,
+                                                             qscheme=torch.per_tensor_symmetric,
+                                                             dtype=torch.qint8)
+
+# uniform activation and weight, b=8 k=2
+uniform_qconfig_8bit = QConfig(activation=default_symmetric_fake_quant,
+                               weight=default_weight_symmetric_fake_quant.with_args)
+
+# uniform activation, APoT weight, b=8 k=2
+apot_weight_qconfig_8bit = QConfig(activation=default_symmetric_fake_quant.with_args,
+                                   weight=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.qint8))
+
+# APoT activation and uniform weight, b=8 k=2
+apot_qconfig_8bit = QConfig(activation=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.quint8),
+                            weight=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.qint8))
+
+# uniform activation and weight, b=4 k=2
+uniform_qconfig_4bit = QConfig(activation=default_symmetric_fake_quant.with_args(quant_min=0,
+                                                                                 quant_max=15),
+                               weight=default_weight_symmetric_fake_quant.with_args(quant_min=0,
+                                                                                    quant_max=15))
+
+# uniform activation, APoT weight, b=4 k=2
+apot_weight_qconfig_4bit = QConfig(activation=default_symmetric_fake_quant.with_args(quant_min=0,
+                                                                                     quant_max=15),
+                                   weight=APoTFakeQuantize.with_args(b=4, k=2, dtype=torch.qint8))
+
+# APoT activation and uniform weight, b=4 k=2
+apot_qconfig_4bit = QConfig(activation=APoTFakeQuantize.with_args(b=4, k=2, dtype=torch.quint8),
+                            weight=APoTFakeQuantize.with_args(b=4, k=2, dtype=torch.qint8))
diff --git a/torch/ao/quantization/experimental/quantizer.py b/torch/ao/quantization/experimental/quantizer.py
index 3894435..1d8845c 100644
--- a/torch/ao/quantization/experimental/quantizer.py
+++ b/torch/ao/quantization/experimental/quantizer.py
@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
 import numpy as np
-from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float
+from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float, quant_dequant_util
 
 # class to store APoT quantizer and
 # implement quantize and dequantize
@@ -52,9 +52,9 @@
     based on the calculated quantization levels from a specified APoT non-uniform observer.
     The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
     Args:
-        apot_tensor: quantized APoT Tensor to dequantize
+        tensor2quantize: fp Tensor
     Returns:
-        result: fp representation of input Tensor
+        result: fp reduced precision representation of input Tensor
     """
     def dequantize(self, apot_tensor) -> Tensor:
         orig_size = apot_tensor.data.size()
@@ -72,6 +72,21 @@
 
         return result
 
+    r""" Returns result of quantize -> dequantize on a fp Tensor (reduced precision)
+    based on the calculated quantization levels from a specified APoT non-uniform observer.
+    The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
+    Args:
+        apot_tensor: quantized APoT Tensor to dequantize
+    Returns:
+        result: fp representation of input Tensor
+    """
+    def quant_dequant(self, tensor2quantize: Tensor) -> Tensor:
+        levels_lst = list(self.quantization_levels)
+
+        result = tensor2quantize.apply_(lambda x: quant_dequant_util(x, levels_lst))
+
+        return result
+
     def q_apot_alpha(self) -> float:
         raise NotImplementedError
 
@@ -100,3 +115,22 @@
     quantizer = apot_tensor.quantizer
     result = quantizer.dequantize(apot_tensor)
     return result
+
+r""" Global method to create quantizer and call quantizer quant_dequant
+    Args:
+        tensor2quantize: fp Tensor to quantize
+        alpha: Tensor qparam alpha (clipping level)
+        gamma: Tensor qparam gamma (scale factor for quantization levels)
+        quantization levels: Tensor with fp quantization levels
+        level indices: Tensor with integer quantization level indices
+    Returns:
+        result: fp reduced precision Tensor from tensor2quantize
+"""
+def quant_dequant_APoT(tensor2quantize: Tensor,
+                       alpha: Tensor,
+                       gamma: Tensor,
+                       quantization_levels: Tensor,
+                       level_indices: Tensor) -> Tensor:
+    quantizer = APoTQuantizer(alpha=alpha, gamma=gamma, quantization_levels=quantization_levels, level_indices=level_indices)
+    result = quantizer.quant_dequant(tensor2quantize)
+    return result