fuse layernorm + quantize (#44232)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44232

Enhance the fake-fp16 LayerNorm operator so it can optionally quantize its output.
Add a fusion pass that replaces LayerNormFakeFP16NNPI + Int8QuantizeNNPI pairs with the new fused operator (LayerNormInt8QuantizeFakeNNPI).
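
For context, a minimal numpy sketch of the computation the fused operator emulates (an
illustrative reference only: the fake op also applies fp16 rounding at intermediate steps,
gamma/beta handling is omitted, and the helper name below is hypothetical):

    import numpy as np

    def layernorm_int8_quantize_ref(X, Y_scale, Y_zero_point, eps=1e-4):
        # LayerNorm over the last axis (gamma/beta omitted for brevity).
        mean = X.mean(axis=1, keepdims=True)
        std = np.sqrt(X.var(axis=1, keepdims=True) + eps)
        Y = (X - mean) / std
        # Affine quantization to uint8, as Int8Quantize would produce.
        q = np.round(Y / Y_scale + Y_zero_point)
        return np.clip(q, 0, 255).astype(np.uint8)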

Test Plan:
tested layernorm
net_runner

P141557987

Reviewed By: venkatacrc

Differential Revision: D23510893

fbshipit-source-id: 32f57ba2090d35d86dcc951e0f3f6a8901ab3153
diff --git a/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.cc b/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.cc
index ebdd559..96098a4 100644
--- a/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.cc
+++ b/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.cc
@@ -5,8 +5,7 @@
 
 namespace caffe2 {
 
-template <>
-void LayerNormFakeFp16Op<CPUContext>::calcY(
+void LayerNormUtils::calcY(
     const int M,
     const int N,
     const float* X,
@@ -52,8 +51,7 @@
   }
 }
 
-template <>
-float LayerNormFakeFp16Op<CPUContext>::ReducedAdd(const std::vector<float>& vec) {
+float LayerNormUtils::ReducedAdd(const std::vector<float>& vec) {
   constexpr int VEC_SIZE = 32;
   std::vector<float> v(vec.begin(), vec.end());
 
@@ -69,8 +67,7 @@
   return v[0];
 }
 
-template <>
-void LayerNormFakeFp16Op<CPUContext>::calcMeanStd(
+void LayerNormUtils::calcMeanStd(
     const int M,
     const int N,
     const float eps,
@@ -191,7 +188,14 @@
     FLAGS_caffe2_fbgemm_fake_fp16_clamp);
 }
 
-REGISTER_CPU_OPERATOR(LayerNormFakeFP16NNPI, LayerNormFakeFp16Op<CPUContext>);
+REGISTER_CPU_OPERATOR(LayerNormFakeFP16NNPI, LayerNormFakeFp16Op<false>);
 OPERATOR_SCHEMA(LayerNormFakeFP16NNPI).NumInputs({1, 3}).NumOutputs(3);
 
+REGISTER_CPU_OPERATOR(LayerNormInt8QuantizeFakeNNPI,
+                      LayerNormFakeFp16Op<true>);
+OPERATOR_SCHEMA(LayerNormInt8QuantizeFakeNNPI)
+    .IdenticalTypeAndShape()
+    .NumInputs({1, 3})
+    .NumOutputs(3);
+
 } // namespace caffe2
diff --git a/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.h b/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.h
index afa3461..4aef846 100644
--- a/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.h
+++ b/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.h
@@ -9,8 +9,6 @@
 #include "caffe2/core/context.h"
 #include "caffe2/core/operator.h"
 
-//#include "caffe2/fb/fbgemm/fbgemm_fp16/include/fbgemm/FbgemmFloat16.h"
-//#include <fbgemm/FbgemmFloat16.h>
 #include <fbgemm/FbgemmConvert.h>
 #include "caffe2/utils/eigen_utils.h"
 #include "caffe2/utils/math.h"
@@ -20,11 +18,33 @@
 
 namespace caffe2 {
 
-template <class Context>
-class LayerNormFakeFp16Op final : public Operator<Context> {
- public:
-  USE_OPERATOR_CONTEXT_FUNCTIONS;
 
+class LayerNormUtils {
+ public:
+  static void calcY(
+      const int M,
+      const int N,
+      const float* X,
+      const float* mean,
+      const float* std,
+      const float* gamma,
+      const float* beta,
+      float* Y);
+
+  static void calcMeanStd(
+      const int M,
+      const int N,
+      const float eps,
+      const float* X,
+      float* mean,
+      float* std);
+
+  static float ReducedAdd(const std::vector<float>& vec);
+};
+
+template <bool quantizeOutput = false>
+class LayerNormFakeFp16Op final : public Operator<CPUContext> {
+ public:
   template <class... Args>
   explicit LayerNormFakeFp16Op(Args&&... args)
       : Operator<CPUContext>(std::forward<Args>(args)...),
@@ -39,7 +59,14 @@
 
   bool DoRunWithType() {
     const auto& X = Input(INPUT);
-    auto* Y = Output(OUTPUT, X.sizes(), at::dtype<float>());
+    std::vector<float> Y_fp16;
+
+    Tensor* Y = nullptr;
+    if (!quantizeOutput) {
+      Y = Output(OUTPUT, X.sizes(), at::dtype<float>());
+    } else {
+      Y_fp16.resize(X.numel());
+    }
     CAFFE_ENFORCE_GE(X.dim(), 2, "LayerNorm requires input dim >=2.");
     const int canonical_axis = X.canonical_axis_index(axis_);
     std::vector<int64_t> moments_dims(
@@ -49,9 +76,20 @@
     auto* sigma = Output(STD, moments_dims, at::dtype<float>());
     const int M = X.size_to_dim(canonical_axis);
     const int N = X.size_from_dim(canonical_axis);
-    Y->ResizeLike(X);
+
+    if (!quantizeOutput) {
+      Y->ResizeLike(X);
+    }
+
     const float* X_data = X.template data<float>();
-    float* Y_data = Y->template mutable_data<float>();
+
+    float *Y_data;
+    if (!quantizeOutput) {
+      Y_data = Y->template mutable_data<float>();
+    } else {
+      Y_data = Y_fp16.data();
+    }
+
     float* mean_data = mean->template mutable_data<float>();
     float* sigma_data = sigma->template mutable_data<float>();
 
@@ -65,7 +103,7 @@
     X_data = X_rounded.data();
 
     // Mean and Standard Deviation computation for the input data
-    calcMeanStd(M, N, epsilon_, X_data, mean_data, sigma_data);
+    LayerNormUtils::calcMeanStd(M, N, epsilon_, X_data, mean_data, sigma_data);
 
     const float* gamma_data = nullptr;
     const float* beta_data = nullptr;
@@ -99,7 +137,7 @@
     }
 
     // Layer Normalized Output computation
-    calcY(
+    LayerNormUtils::calcY(
         M, N, X_data, mean_data, sigma_data, gamma_data, beta_data, Y_data);
 
     if (InputSize() == 3 && !elementwise_affine_) {
@@ -128,30 +166,62 @@
       }
     }
 
+    // Quantize
+    // Ideally we would reuse the same quantization function as Int8Quantize,
+    // but we need to adjust for the int8 vs. uint8 bias. A simple shift of the
+    // output is not enough because it causes rounding problems inside the fma.
+    // TODO: figure out how to commonize this with Int8Quantize
+    if (quantizeOutput) {
+      auto* Y_int8 = Outputs()[0]->template GetMutable<int8::Int8TensorCPU>();
+      Y_int8->t.ResizeLike(X);
+
+      int32_t Y_offset =
+          this->template GetSingleArgument<int>("Y_zero_point", 0);
+      auto Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
+
+      float inv_scale = 1.0f / Y_scale;
+      fbgemm::RoundToFloat16(
+        &inv_scale, &inv_scale, 1, false /* no clamping */);
+
+      Y_int8->scale = Y_scale;
+      Y_int8->zero_point = Y_offset;
+
+      int Nout = X.numel();
+
+      std::vector<float> inv_scalev(Nout, inv_scale);
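+      // Apply the zero point with a -128 bias during the fp16 fma so the
+      // intermediate values stay near the signed int8 range; the +128 is
+      // added back after rounding below.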
+      std::vector<float> offsetv(Nout, Y_offset - 128.0);
+      uint8_t* Y_uint8_data = Y_int8->t.template mutable_data<uint8_t>();
+
+      fake_fp16::fma_fp16(Nout, Y_fp16.data(), inv_scalev.data(), offsetv.data());
+
+      const int32_t qmin = std::numeric_limits<uint8_t>::min();
+      const int32_t qmax = std::numeric_limits<uint8_t>::max();
+
+      for (int i = 0; i < Nout; i++) {
+        float halfRes = offsetv[i];
+        halfRes = round(halfRes);
+        halfRes = halfRes + 128.0;
+        if (std::isinf(halfRes)) {
+          if (halfRes > 0) {
+            halfRes = qmax;
+          } else {
+            halfRes = qmin;
+          }
+        }
+        if (halfRes > qmax) {
+          halfRes = qmax;
+        }
+        if (halfRes < qmin) {
+          halfRes = qmin;
+        }
+        Y_uint8_data[i] = static_cast<uint8_t>(halfRes);
+      }
+    }
+
     return true;
   }
 
  private:
-  void calcY(
-      const int M,
-      const int N,
-      const float* X,
-      const float* mean,
-      const float* std,
-      const float* gamma,
-      const float* beta,
-      float* Y);
-
-  void calcMeanStd(
-      const int M,
-      const int N,
-      const float eps,
-      const float* X,
-      float* mean,
-      float* std);
-
-  float ReducedAdd(const std::vector<float>& vec);
-
   const int axis_;
   const float epsilon_;
   const bool elementwise_affine_;
diff --git a/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py b/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py
index ec99ca4..698b839 100644
--- a/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py
+++ b/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py
@@ -120,3 +120,125 @@
                 }
             )
             assert(0)
+
+    def _get_scale_zp(self, tensor):
+        tensor_max = np.max(tensor)
+        tensor_min = min(0, np.min(tensor))
+        scale = np.float32(np.float16((tensor_max - tensor_min) / 255.0))
+        if scale < 1e-6:
+            scale = 1e-6
+        zero_point = 0 - tensor_min / scale
+        zero_point = int(round(np.clip(zero_point, 0, 255.0)))
+        return (scale, zero_point)
+
+    def _layernorm_transform(self, X):
+        mean = np.mean(X, axis=1)
+        mean_exp = np.outer(mean, np.ones(X.shape[1]))
+        std = np.std(X, axis=1)
+        std_exp = np.outer(std, np.ones(X.shape[1]))
+        Y = (X - mean_exp) / std_exp
+        return Y
+
+    @given(seed=st.integers(0, 65535),
+           batch_size=st.integers(min_value=1, max_value=50),
+           size=st.integers(min_value=2, max_value=128),
+           epsilon=st.floats(min_value=1e-4, max_value=1e-3),
+           elementwise_affine=st.booleans())
+    @settings(deadline=None)
+    # re-enable when T74553975 gets fixed
+    def Skip_test_fused_ln_quantize(self, seed, batch_size, size, epsilon, elementwise_affine):
+        np.random.seed(seed)
+
+        # Reset the workspace
+        workspace.ResetWorkspace()
+        axis = 1
+
+        dims = np.array(([batch_size, size]))
+        X = np.random.uniform(size=dims).astype(np.float32) - 0.5
+        gamma = np.random.randn(*X.shape[axis:]).astype(np.float32)
+        beta = np.random.randn(*X.shape[axis:]).astype(np.float32)
+
+        Y = self._layernorm_transform(X)
+        scale, zp = self._get_scale_zp(Y)
+
+        pred_net = caffe2_pb2.NetDef()
+        pred_net.name = "pred"
+        pred_net.external_input.extend(["X", "gamma", "beta"])
+        pred_net.external_output.extend(["Y_q"])
+        pred_net.op.add().CopyFrom(
+            core.CreateOperator(
+                "LayerNorm",
+                ["X", "gamma", "beta"] if elementwise_affine else ["X"],
+                ["Y", "mean", "rstd"],
+                axis=axis,
+                epsilon=epsilon,
+                elementwise_affine=elementwise_affine
+            )
+        )
+        pred_net.op.add().CopyFrom(
+            core.CreateOperator(
+                "Int8Quantize", ["Y"], ["Y_q"], Y_scale=scale, Y_zero_point=zp
+            )
+        )
+
+        print(pred_net)
+        pred_net_ref = caffe2_pb2.NetDef()
+        pred_net_ref.name = "pred_ref"
+        pred_net_ref.external_input.extend(["X", "gamma", "beta"])
+        pred_net_ref.external_output.extend(["Y_q"])
+        pred_net_ref.op.add().CopyFrom(
+            core.CreateOperator(
+                "LayerNormInt8QuantizeFakeNNPI",
+                ["X", "gamma", "beta"] if elementwise_affine else ["X"],
+                ["Y_q", "mean", "rstd"],
+                axis=axis,
+                epsilon=epsilon,
+                elementwise_affine=elementwise_affine,
+                Y_scale=scale, Y_zero_point=zp
+            )
+        )
+        shape_hints = {"X": X.shape, "gamma": gamma.shape, "beta": beta.shape}
+        pred_net_onnxified = onnxifi_caffe2_net(
+            pred_net,
+            shape_hints,
+            debug=True,
+            adjust_batch=True,
+            use_onnx=False
+        )
+        num_onnxified_ops = sum(
+            1 if o.type == "Onnxifi" else 0 for o in pred_net_onnxified.op)
+        np.testing.assert_equal(num_onnxified_ops, 1)
+
+        workspace.FeedBlob("X", X)
+        workspace.FeedBlob("gamma", gamma)
+        workspace.FeedBlob("beta", beta)
+
+        workspace.CreateNet(pred_net_ref)
+        workspace.CreateNet(pred_net_onnxified)
+
+        workspace.RunNet(pred_net_ref.name)
+        Y_c2 = workspace.FetchInt8Blob("Y_q")
+
+        workspace.RunNet(pred_net_onnxified.name)
+        Y_glow = workspace.FetchInt8Blob("Y_q")
+
+        if not np.allclose(Y_glow.data, Y_c2.data) or \
+           Y_glow.scale != Y_c2.scale or Y_glow.zero_point != Y_c2.zero_point:
+            diff_Y = np.abs(Y_glow.data.astype(np.float32) - Y_c2.data.astype(np.float32))
+            print_test_debug_info(
+                "layernorm",
+                {
+                    "seed": seed,
+                    "size": size,
+                    "batch_size": batch_size,
+                    "epsilon": epsilon,
+                    "gamma": gamma,
+                    "beta": beta,
+                    "elementwise_affine": elementwise_affine,
+                    "X": X,
+                    "Y_glow": Y_glow,
+                    "Y_c2": Y_c2,
+                    "diff_Y": diff_Y,
+                }
+            )
+            assert(0)
diff --git a/caffe2/opt/fakefp16_transform.cc b/caffe2/opt/fakefp16_transform.cc
index e9381fd..424056b 100644
--- a/caffe2/opt/fakefp16_transform.cc
+++ b/caffe2/opt/fakefp16_transform.cc
@@ -92,8 +92,8 @@
         continue;
       }
 
-      const std::string& lm_output = op.output(0);
-      auto next_ops = findMutableOperatorByInput(net, lm_output);
+      const std::string& ln_output = op.output(0);
+      auto next_ops = findMutableOperatorByInput(net, ln_output);
 
       if (next_ops.size() != 1 || next_ops[0]->type() != "MulFakeFp16") {
         LOG(INFO) << "next op isn't MulFakeFp16, skipping";
@@ -124,6 +124,51 @@
   }
 }
 
+void fakeFp16FoldLayerNormQuant(NetDef* net) {
+  for (auto& op : *net->mutable_op()) {
+    if (op.type() == "LayerNormFakeFP16NNPI") {
+      auto layernormNetPos = ArgumentHelper::GetSingleArgument<OperatorDef, int>(
+                             op, "net_pos", -1);
+      LOG(INFO) << "Attemping to fuse LayerNormFakeFP16NNPI w Quant at "
+                << layernormNetPos;
+      if (op.input().size() != 1) {
+        LOG(INFO) << "input isn't 1, is " << op.input().size() << " skipping";
+        continue;
+      }
+
+      const std::string& ln_output = op.output(0);
+      auto next_ops = findMutableOperatorByInput(net, ln_output);
+
+      if (next_ops.size() != 1 || next_ops[0]->type() != "Int8QuantizeNNPI") {
+        LOG(INFO) << "next op isn't Int8QuantizeNNPI, skipping";
+        continue;
+      }
+
+      auto* quantOp = next_ops[0];
+
+      if (quantOp->output().size() != 1) {
+        LOG(INFO) << "more than one output for quant, skipping";
+        continue;
+      }
+
+      op.set_type("LayerNormInt8QuantizeFakeNNPI");
+
+      *op.mutable_output(0) = quantOp->output(0);
+      op.add_arg()->CopyFrom(MakeArgument("Y_scale",
+                      ArgumentHelper::GetSingleArgument<OperatorDef, float>(*quantOp, "Y_scale", -1)));
+      op.add_arg()->CopyFrom(MakeArgument("Y_zero_point",
+                      ArgumentHelper::GetSingleArgument<OperatorDef, int>(*quantOp, "Y_zero_point", -1)));
+
+      auto quantNetPos = ArgumentHelper::GetSingleArgument<OperatorDef, int>(
+                          *quantOp, "net_pos", -1);
+
+      quantOp->set_type("delete_me_optimized_away");
+
+      LOG(INFO) << "Fused LayerNormFakeFP16NNPI w Quant at " << layernormNetPos << " " << quantNetPos;
+    }
+  }
+}
+
 void fakeFp16FoldSwish(NetDef* net) {
   // find a sequence deq->swish->quant and replace it
   for (auto& op : *net->mutable_op()) {
@@ -236,6 +281,7 @@
   fakeFp16FoldLayerNorm(net);
   fakeFp16FoldSwish(net);
   fakeFp16FoldTanhQuant(net);
+  fakeFp16FoldLayerNormQuant(net);
 
   auto iter = net->mutable_op()->begin();
   while (iter != net->mutable_op()->end()) {
diff --git a/caffe2/opt/onnxifi_transformer.cc b/caffe2/opt/onnxifi_transformer.cc
index 9bceb56..9166153 100644
--- a/caffe2/opt/onnxifi_transformer.cc
+++ b/caffe2/opt/onnxifi_transformer.cc
@@ -1102,7 +1102,7 @@
     int pos =
         ArgumentHelper::GetSingleArgument<OperatorDef, int>(op, kNetPos, -1);
     if (blocklisted_ops.count(pos)) {
-      LOG(INFO) << "Skipping blocklisted op " << op.type() << " at pos " << pos;
+      LOG(INFO) << "Skipping blacklisted op " << op.type() << " at pos " << pos;
       return false;
     }
     const OpSchema* schema = OpSchemaRegistry::Schema(op.type());