Provide option to save quantized data for DNNLOWP without layout optimization (#19681)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19681
For accelerators, we need to lower just the quantized weight data without the layout transformation applied during packing. This diff adds an option to do that.
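Usage sketch (not part of this diff): the flag is read as a regular operator argument, so it can be set when the weight-packing op is created. The operator name, blob names, and conv arguments below are assumptions for illustration only.

from caffe2.python import core

# Assumed op/blob names; "save_unpacked_weights" is the argument added in this diff.
pack_w = core.CreateOperator(
    "Int8ConvPackWeight",       # DNNLOWP conv weight-packing op (assumed name)
    ["W"],                      # fp32 or Int8TensorCPU filter blob
    ["W_packed"],               # packed weight blob
    engine="DNNLOWP",
    kernel=3,                   # illustrative conv args
    group=1,
    quantize_groupwise=0,
    save_unpacked_weights=1,    # also keep the unpacked quantized weights
)
# pack_w would then be appended to the model's init net as usual.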
Reviewed By: jerryzh168, zrphercule
Differential Revision: D15066568
fbshipit-source-id: 133d749e087c2ad4a899bee5e96f597f70b2443c
diff --git a/caffe2/quantization/server/fbgemm_pack_op.cc b/caffe2/quantization/server/fbgemm_pack_op.cc
index 1b5386f..7b28da9 100644
--- a/caffe2/quantization/server/fbgemm_pack_op.cc
+++ b/caffe2/quantization/server/fbgemm_pack_op.cc
@@ -326,6 +326,8 @@
const OperatorDef& operator_def,
Workspace* ws)
: ConvPoolDNNLowPOpBase<uint8_t, ConvFp32Op>(operator_def, ws),
+ save_unpacked_weights_(
+ this->GetSingleArgument<bool>("save_unpacked_weights", false)),
quantize_groupwise_(
this->GetSingleArgument<bool>("quantize_groupwise", false)) {
if (this->debug_def().engine() == "DNNLOWP_ACC16") {
@@ -419,6 +421,13 @@
Y->qparams,
W_quantized,
qfactory_.get());
+ if (save_unpacked_weights_) {
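+ // Keep a copy of the quantized weights in their original (unpacked) layout
+ // so they can be lowered to the accelerator as-is.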
+ ReinitializeTensor(&Y->original_tensor, filter.sizes(), at::dtype<int8_t>().device(CPU));
+ auto* buffer =
+ Y->original_tensor.template mutable_data<int8_t>();
+ CAFFE_ENFORCE_EQ(Y->original_tensor.numel(), W_quantized.size());
+ memcpy(buffer, W_quantized.data(), W_quantized.size() * sizeof(int8_t));
+ }
if (this->InputIsType<int8::Int8TensorCPU>(FILTER) && quantize_groupwise_) {
static int log_occurences = 0;
diff --git a/caffe2/quantization/server/fbgemm_pack_op.h b/caffe2/quantization/server/fbgemm_pack_op.h
index a2a6c9d..db7ff52 100644
--- a/caffe2/quantization/server/fbgemm_pack_op.h
+++ b/caffe2/quantization/server/fbgemm_pack_op.h
@@ -56,6 +56,9 @@
bool TakeDepthWise3x3x3FastPath_();
bool TakeGConvFastPath_();
+ // If true, save the quantized weights right after quantization, before they
+ // are packed into the performance-optimized layout.
+ bool save_unpacked_weights_;
bool quantize_groupwise_;
int nbits_in_non_outlier_; // only for DNNLOWP_ACC16