from __future__ import absolute_import, division, print_function, unicode_literals
import torch.nn as nn
import torch.nn.functional as F

class Hardswish(nn.Hardswish):
    r"""
    A Hardswish module attached with a FakeQuantize module for the output
    activation, used for quantization aware training.

    Similar to `torch.nn.Hardswish`, with the FakeQuantize module initialized
    to default.

    Attributes:
        activation_post_process: fake quant module for output activation
    """
    _FLOAT_MODULE = nn.Hardswish

    def __init__(self, qconfig=None):
        super(Hardswish, self).__init__()
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig
        self.activation_post_process = qconfig.activation()

    def forward(self, input):
        return self.activation_post_process(F.hardswish(input))
    @classmethod
    def from_float(cls, mod, qconfig=None):
        r"""Create a qat module from a float module.

        Args: `mod` a float module, either produced by torch.quantization
            utilities or directly from user
        """
        assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
            cls._FLOAT_MODULE.__name__
        if not qconfig:
            assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
            assert mod.qconfig, 'Input float module must have a valid qconfig'
            qconfig = mod.qconfig
        qat_hardswish = cls(qconfig=qconfig)
        return qat_hardswish
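

# Usage sketch (illustrative, not part of the original module): shows both
# construction paths, direct instantiation with an explicit qconfig and
# conversion from a float module via ``from_float``. Assumes
# ``torch.quantization.get_default_qat_qconfig`` is available, as in the
# PyTorch releases that ship this QAT module.
if __name__ == '__main__':
    import torch
    import torch.quantization

    qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

    # Direct construction: the qconfig supplies the FakeQuantize factory
    # used for activation_post_process.
    qat_hs = Hardswish(qconfig=qconfig)
    out = qat_hs(torch.randn(2, 4))  # hardswish output passed through fake quant

    # Conversion path: the float module must carry a qconfig attribute.
    float_hs = nn.Hardswish()
    float_hs.qconfig = qconfig
    qat_hs_converted = Hardswish.from_float(float_hs)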