[quant] Implement PTQ for APoT FakeQuant (#81040)
### Summary:
This PR implements PTQ for APoT FakeQuant. It runs models (Resnet-18 pre-trained model, ImageNet dataset) to compare accuracy metrics for different qconfig settings of uniform vs. APoT quantized activation and weight.
According to the collected accuracy stats, model #2 (uniform activation and APoT weight) appears to have a slight improvement in accuracy compared to model #1 (uniform activation and uniform weight) for 8-bit and significant improvement for 4-bit (see "Accuracy Stats" section below).
### Test Plan:
Run models with: `python test/quantization/core/experimental/fx_graph_mode_apot.py`
### Accuracy Stats:
8-bit (Uniform int8, APoT b = 8 k = 2)
**Model #1:** Uniform activation, uniform weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 64.43% (Top-1), 85.62% (Top-5)
**Model #2:** Uniform activation, APoT weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 64.51% (Top-1), 85.78% (Top-5)
**Model #3:** APoT activation, APoT weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 64.32% (Top-1), 85.78% (Top-5)
4-bit (Uniform int4, APoT b = 4 k = 2)
**Model #1:** Uniform activation, uniform weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 45.63% (Top-1), 71.96% (Top-5)
**Model #2:** Uniform activation, APoT weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 64.24% (Top-1), 85.56% (Top-5)
**Model #3:** APoT activation, APoT weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 45.40% (Top-1), 76.21% (Top-5)
**Full Precision model (FX Graph Mode quantized)**
Evaluation accuracy on test dataset: 69.76% (Top-1), 89.08% (Top-5)
**Eager mode quantized model**
Evaluation accuracy on test dataset: 69.49% (Top-1), 88.90% (Top-5)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81040
Approved by: https://github.com/jerryzh168
diff --git a/mypy.ini b/mypy.ini
index 248f52f..eb6b502 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -73,6 +73,9 @@
[mypy-torch.ao.quantization.experimental.fake_quantize_function]
ignore_missing_imports = True
+[mypy-torch.ao.quantization.experimental.fake_quantize]
+ignore_missing_imports = True
+
#
# Files with various errors. Mostly real errors, possibly some false
# positives as well.
diff --git a/test/quantization/core/experimental/data/resnet18_pretrained_float.pth b/test/quantization/core/experimental/data/resnet18_pretrained_float.pth
new file mode 100644
index 0000000..c049f92
--- /dev/null
+++ b/test/quantization/core/experimental/data/resnet18_pretrained_float.pth
Binary files differ
diff --git a/test/quantization/core/experimental/fx_graph_mode_apot.py b/test/quantization/core/experimental/fx_graph_mode_apot.py
new file mode 100644
index 0000000..2ef7168
--- /dev/null
+++ b/test/quantization/core/experimental/fx_graph_mode_apot.py
@@ -0,0 +1,257 @@
+import torch
+import torch.nn as nn
+import torchvision
+import torchvision.transforms.transforms as transforms
+import os
+import torch.quantization
+
+# Setup warnings
+import warnings
+warnings.filterwarnings(
+ action='ignore',
+ category=DeprecationWarning,
+ module=r'.*'
+)
+warnings.filterwarnings(
+ action='default',
+ module=r'torch.quantization'
+)
+
+"""
+Define helper functions
+"""
+
+# Specify random seed for repeatable results
+_ = torch.manual_seed(191009)
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+ def __init__(self, name, fmt=':f'):
+ self.name = name
+ self.fmt = fmt
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+ def __str__(self):
+ fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+ return fmtstr.format(**self.__dict__)
+
+
+def accuracy(output, target, topk=(1,)):
+ """Computes the accuracy over the k top predictions for the specified values of k"""
+ with torch.no_grad():
+ maxk = max(topk)
+ batch_size = target.size(0)
+
+ _, pred = output.topk(maxk, 1, True, True)
+ pred = pred.t()
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+ res = []
+ for k in topk:
+ correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
+ res.append(correct_k.mul_(100.0 / batch_size))
+ return res
+
+
+def evaluate(model, criterion, data_loader):
+ model.eval()
+ top1 = AverageMeter('Acc@1', ':6.2f')
+ top5 = AverageMeter('Acc@5', ':6.2f')
+ cnt = 0
+ with torch.no_grad():
+ for image, target in data_loader:
+ output = model(image)
+ loss = criterion(output, target)
+ cnt += 1
+ acc1, acc5 = accuracy(output, target, topk=(1, 5))
+ top1.update(acc1[0], image.size(0))
+ top5.update(acc5[0], image.size(0))
+ print('')
+
+ return top1, top5
+
+def load_model(model_file):
+ model = resnet18(pretrained=False)
+ state_dict = torch.load(model_file)
+ model.load_state_dict(state_dict)
+ model.to("cpu")
+ return model
+
+def print_size_of_model(model):
+ if isinstance(model, torch.jit.RecursiveScriptModule):
+ torch.jit.save(model, "temp.p")
+ else:
+ torch.jit.save(torch.jit.script(model), "temp.p")
+ print("Size (MB):", os.path.getsize("temp.p") / 1e6)
+ os.remove("temp.p")
+
+def prepare_data_loaders(data_path):
+
+ normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+ dataset = torchvision.datasets.ImageNet(data_path,
+ split="train",
+ transform=transforms.Compose([transforms.RandomResizedCrop(224),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ normalize]))
+ dataset_test = torchvision.datasets.ImageNet(data_path,
+ split="val",
+ transform=transforms.Compose([transforms.Resize(256),
+ transforms.CenterCrop(224),
+ transforms.ToTensor(),
+ normalize]))
+
+ train_sampler = torch.utils.data.RandomSampler(dataset)
+ test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+
+ data_loader = torch.utils.data.DataLoader(
+ dataset, batch_size=train_batch_size,
+ sampler=train_sampler)
+
+ data_loader_test = torch.utils.data.DataLoader(
+ dataset_test, batch_size=eval_batch_size,
+ sampler=test_sampler)
+
+ return data_loader, data_loader_test
+
+data_path = '~/my_imagenet/'
+saved_model_dir = '/data/home/amandaliu/cluster/pytorch/test/quantization/core/experimental/data/'
+float_model_file = 'resnet18_pretrained_float.pth'
+
+train_batch_size = 30
+eval_batch_size = 50
+
+data_loader, data_loader_test = prepare_data_loaders(data_path)
+criterion = nn.CrossEntropyLoss()
+float_model = load_model(saved_model_dir + float_model_file).to("cpu")
+float_model.eval()
+
+# deepcopy the model since we need to keep the original model around
+import copy
+model_to_quantize = copy.deepcopy(float_model)
+
+model_to_quantize.eval()
+
+"""
+Prepare models
+"""
+
+# Note that this is temporary, we'll expose these functions to torch.quantization after official releasee
+from torch.quantization.quantize_fx import prepare_fx, convert_fx
+
+def calibrate(model, data_loader):
+ model.eval()
+ with torch.no_grad():
+ for image, target in data_loader:
+ model(image)
+
+from torch.ao.quantization.experimental.qconfig import (
+ uniform_qconfig_8bit,
+ apot_weights_qconfig_8bit,
+ apot_qconfig_8bit,
+ uniform_qconfig_4bit,
+ apot_weights_qconfig_4bit,
+ apot_qconfig_4bit
+)
+
+"""
+Prepare full precision model
+"""
+full_precision_model = float_model
+
+top1, top5 = evaluate(full_precision_model, criterion, data_loader_test)
+print("Model #0 Evaluation accuracy on test dataset: %2.2f, %2.2f" % (top1.avg, top5.avg))
+
+"""
+Prepare model PTQ for specified qconfig for torch.nn.Linear
+"""
+def prepare_ptq_linear(qconfig):
+ qconfig_dict = {"object_type": [(torch.nn.Linear, qconfig)]}
+ prepared_model = prepare_fx(copy.deepcopy(float_model), qconfig_dict) # fuse modules and insert observers
+ calibrate(prepared_model, data_loader_test) # run calibration on sample data
+ return prepared_model
+
+"""
+Prepare model with uniform activation, uniform weight
+b=8, k=2
+"""
+
+prepared_model = prepare_ptq_linear(uniform_qconfig_8bit)
+quantized_model = convert_fx(prepared_model) # convert the calibrated model to a quantized model
+
+top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
+print("Model #1 Evaluation accuracy on test dataset (b=8, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
+
+"""
+Prepare model with uniform activation, uniform weight
+b=4, k=2
+"""
+
+prepared_model = prepare_ptq_linear(uniform_qconfig_4bit)
+quantized_model = convert_fx(prepared_model) # convert the calibrated model to a quantized model
+
+top1, top5 = evaluate(quantized_model1, criterion, data_loader_test)
+print("Model #1 Evaluation accuracy on test dataset (b=4, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
+
+"""
+Prepare model with uniform activation, APoT weight
+(b=8, k=2)
+"""
+
+prepared_model = prepare_ptq_linear(apot_weights_qconfig_8bit)
+
+top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
+print("Model #2 Evaluation accuracy on test dataset (b=8, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
+
+"""
+Prepare model with uniform activation, APoT weight
+(b=4, k=2)
+"""
+
+prepared_model = prepare_ptq_linear(apot_weights_qconfig_4bit)
+
+top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
+print("Model #2 Evaluation accuracy on test dataset (b=4, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
+
+
+"""
+Prepare model with APoT activation and weight
+(b=8, k=2)
+"""
+
+prepared_model = prepare_ptq_linear(apot_qconfig_8bit)
+
+top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
+print("Model #3 Evaluation accuracy on test dataset (b=8, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
+
+"""
+Prepare model with APoT activation and weight
+(b=4, k=2)
+"""
+
+prepared_model = prepare_ptq_linear(apot_qconfig_4bit)
+
+top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
+print("Model #3 Evaluation accuracy on test dataset (b=4, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
+
+"""
+Prepare eager mode quantized model
+"""
+
+from torchvision.models.quantization.resnet import resnet18
+eager_quantized_model = resnet18(pretrained=True, quantize=True).eval()
+top1, top5 = evaluate(eager_quantized_model, criterion, data_loader_test)
+print("Eager mode quantized model evaluation accuracy on test dataset: %2.2f, %2.2f" % (top1.avg, top5.avg))
diff --git a/test/quantization/core/experimental/test_fake_quantize.py b/test/quantization/core/experimental/test_fake_quantize.py
index fab526c..4e9464a 100644
--- a/test/quantization/core/experimental/test_fake_quantize.py
+++ b/test/quantization/core/experimental/test_fake_quantize.py
@@ -35,7 +35,7 @@
r""" Tests fake quantize forward() method
by comparing result with expected
- float_to_reduced_precision mapping of input tensor.
+ quant_dequant_APoT mapping of input tensor.
Uses input tensor with random values from 0 -> 1000
and APoT observer with hard-coded values b=4, k=2
"""
diff --git a/torch/ao/quantization/experimental/apot_utils.py b/torch/ao/quantization/experimental/apot_utils.py
index 9804cd8..ad7a7be 100644
--- a/torch/ao/quantization/experimental/apot_utils.py
+++ b/torch/ao/quantization/experimental/apot_utils.py
@@ -33,7 +33,7 @@
reduced precision floating point value
based on quantization levels
"""
-def float_to_reduced_precision(x, levels, indices):
+def quant_dequant_util(x, levels, indices):
levels_lst = list(levels)
indices_lst = list(indices)
diff --git a/torch/ao/quantization/experimental/fake_quantize.py b/torch/ao/quantization/experimental/fake_quantize.py
index c229859..7541106 100644
--- a/torch/ao/quantization/experimental/fake_quantize.py
+++ b/torch/ao/quantization/experimental/fake_quantize.py
@@ -10,18 +10,23 @@
quantization_levels: Tensor
level_indices: Tensor
- def __init__(self, **observer_kwargs):
+ def __init__(self, observer=APoTObserver, **observer_kwargs):
super().__init__()
- self.activation_post_process = APoTObserver(**observer_kwargs)
+ self.activation_post_process = observer(**observer_kwargs)
+ self.dtype = self.activation_post_process.dtype
- def calculate_qparams(self, signed: bool): # type: ignore[override]
+ def calculate_qparams(self, signed=False): # type: ignore[override]
return self.activation_post_process.calculate_qparams(signed=signed)
- def forward(self, X: torch.Tensor, signed: bool): # type: ignore[override]
+ def forward(self, X: torch.Tensor): # type: ignore[override]
if self.observer_enabled[0] == 1:
self.activation_post_process.forward(X)
- self.alpha, self.gamma, self.quantization_levels, self.level_indices = \
- self.activation_post_process.calculate_qparams(signed)
+ result = self.activation_post_process.calculate_qparams(signed=False)
+ self.alpha = result[0]
+ self.gamma = result[1]
+ self.quantization_levels = result[2]
+ self.level_indices = result[3]
+
if self.fake_quant_enabled[0] == 1:
assert (self.alpha is not None
and self.gamma is not None
diff --git a/torch/ao/quantization/experimental/observer.py b/torch/ao/quantization/experimental/observer.py
index 77cb594..244975b 100644
--- a/torch/ao/quantization/experimental/observer.py
+++ b/torch/ao/quantization/experimental/observer.py
@@ -23,7 +23,7 @@
self,
b,
k,
- dtype=torch.int32) -> None:
+ dtype=torch.quint8) -> None:
super().__init__(dtype)
self.b = b
self.k = k
@@ -47,7 +47,7 @@
quantization_levels: non-uniform quantization levels (fp representation)
level_indices: int representation of quantization_levels indices
"""
- def _calculate_qparams(self, signed, min_val=None, max_val=None):
+ def _calculate_qparams(self, signed: bool, min_val=None, max_val=None):
if min_val is not None:
self.min_val = min_val
if max_val is not None:
diff --git a/torch/ao/quantization/experimental/qconfig.py b/torch/ao/quantization/experimental/qconfig.py
new file mode 100644
index 0000000..f9397d1
--- /dev/null
+++ b/torch/ao/quantization/experimental/qconfig.py
@@ -0,0 +1,46 @@
+import torch
+from torch.ao.quantization.qconfig import QConfig
+from torch.ao.quantization import MinMaxObserver
+from torch.ao.quantization.fake_quantize import FakeQuantize
+from torch.ao.quantization.experimental.fake_quantize import APoTFakeQuantize
+
+"""
+Default symmetric fake_quant for activations.
+"""
+default_symmetric_fake_quant = FakeQuantize.with_args(observer=MinMaxObserver,
+ qscheme=torch.per_tensor_symmetric,
+ dtype=torch.quint8)
+
+"""
+Default symmetric fake_quant for weights.
+"""
+default_weight_symmetric_fake_quant = FakeQuantize.with_args(observer=MinMaxObserver,
+ qscheme=torch.per_tensor_symmetric,
+ dtype=torch.qint8)
+
+# uniform activation and weight, b=8 k=2
+uniform_qconfig_8bit = QConfig(activation=default_symmetric_fake_quant,
+ weight=default_weight_symmetric_fake_quant.with_args)
+
+# uniform activation, APoT weight, b=8 k=2
+apot_weight_qconfig_8bit = QConfig(activation=default_symmetric_fake_quant.with_args,
+ weight=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.qint8))
+
+# APoT activation and uniform weight, b=8 k=2
+apot_qconfig_8bit = QConfig(activation=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.quint8),
+ weight=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.qint8))
+
+# uniform activation and weight, b=4 k=2
+uniform_qconfig_4bit = QConfig(activation=default_symmetric_fake_quant.with_args(quant_min=0,
+ quant_max=15),
+ weight=default_weight_symmetric_fake_quant.with_args(quant_min=0,
+ quant_max=15))
+
+# uniform activation, APoT weight, b=4 k=2
+apot_weight_qconfig_4bit = QConfig(activation=default_symmetric_fake_quant.with_args(quant_min=0,
+ quant_max=15),
+ weight=APoTFakeQuantize.with_args(b=4, k=2, dtype=torch.qint8))
+
+# APoT activation and uniform weight, b=4 k=2
+apot_qconfig_4bit = QConfig(activation=APoTFakeQuantize.with_args(b=4, k=2, dtype=torch.quint8),
+ weight=APoTFakeQuantize.with_args(b=4, k=2, dtype=torch.qint8))
diff --git a/torch/ao/quantization/experimental/quantizer.py b/torch/ao/quantization/experimental/quantizer.py
index 3894435..1d8845c 100644
--- a/torch/ao/quantization/experimental/quantizer.py
+++ b/torch/ao/quantization/experimental/quantizer.py
@@ -1,7 +1,7 @@
import torch
from torch import Tensor
import numpy as np
-from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float
+from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float, quant_dequant_util
# class to store APoT quantizer and
# implement quantize and dequantize
@@ -52,9 +52,9 @@
based on the calculated quantization levels from a specified APoT non-uniform observer.
The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
Args:
- apot_tensor: quantized APoT Tensor to dequantize
+ tensor2quantize: fp Tensor
Returns:
- result: fp representation of input Tensor
+ result: fp reduced precision representation of input Tensor
"""
def dequantize(self, apot_tensor) -> Tensor:
orig_size = apot_tensor.data.size()
@@ -72,6 +72,21 @@
return result
+ r""" Returns result of quantize -> dequantize on a fp Tensor (reduced precision)
+ based on the calculated quantization levels from a specified APoT non-uniform observer.
+ The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
+ Args:
+ apot_tensor: quantized APoT Tensor to dequantize
+ Returns:
+ result: fp representation of input Tensor
+ """
+ def quant_dequant(self, tensor2quantize: Tensor) -> Tensor:
+ levels_lst = list(self.quantization_levels)
+
+ result = tensor2quantize.apply_(lambda x: quant_dequant_util(x, levels_lst))
+
+ return result
+
def q_apot_alpha(self) -> float:
raise NotImplementedError
@@ -100,3 +115,22 @@
quantizer = apot_tensor.quantizer
result = quantizer.dequantize(apot_tensor)
return result
+
+r""" Global method to create quantizer and call quantizer quant_dequant
+ Args:
+ tensor2quantize: fp Tensor to quantize
+ alpha: Tensor qparam alpha (clipping level)
+ gamma: Tensor qparam gamma (scale factor for quantization levels)
+ quantization levels: Tensor with fp quantization levels
+ level indices: Tensor with integer quantization level indices
+ Returns:
+ result: fp reduced precision Tensor from tensor2quantize
+"""
+def quant_dequant_APoT(tensor2quantize: Tensor,
+ alpha: Tensor,
+ gamma: Tensor,
+ quantization_levels: Tensor,
+ level_indices: Tensor) -> Tensor:
+ quantizer = APoTQuantizer(alpha=alpha, gamma=gamma, quantization_levels=quantization_levels, level_indices=level_indices)
+ result = quantizer.quant_dequant(tensor2quantize)
+ return result