fuse layernorm + quantize (#44232)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44232
Enhance LayerNorm to optionally quantize its output.
Add fusion code to replace instances of LayerNorm + quantization with the fused operator.
Test Plan:
Tested LayerNorm via net_runner.
P141557987
Reviewed By: venkatacrc
Differential Revision: D23510893
fbshipit-source-id: 32f57ba2090d35d86dcc951e0f3f6a8901ab3153
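
For context, the fused operator emits the same uint8 affine quantization that Int8Quantize would have produced on the LayerNorm output. A minimal NumPy sketch of that reference scheme (mirroring the test helper _get_scale_zp further down in the diff; the function names here are illustrative, and the fused op itself reaches the result through an fp16 fma, as described in the operator comments):

import numpy as np

def get_scale_zp(tensor):
    # uint8 affine quantization parameters, with the scale itself rounded to fp16
    tensor_max = np.max(tensor)
    tensor_min = min(0, np.min(tensor))
    scale = max(np.float32(np.float16((tensor_max - tensor_min) / 255.0)), 1e-6)
    zero_point = int(round(np.clip(-tensor_min / scale, 0, 255.0)))
    return scale, zero_point

def quantize_uint8(y, scale, zero_point):
    # reference quantization: q = clamp(round(y / scale) + zero_point, 0, 255)
    q = np.round(y / scale) + zero_point
    return np.clip(q, 0, 255).astype(np.uint8)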
diff --git a/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.cc b/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.cc
index ebdd559..96098a4 100644
--- a/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.cc
+++ b/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.cc
@@ -5,8 +5,7 @@
namespace caffe2 {
-template <>
-void LayerNormFakeFp16Op<CPUContext>::calcY(
+void LayerNormUtils::calcY(
const int M,
const int N,
const float* X,
@@ -52,8 +51,7 @@
}
}
-template <>
-float LayerNormFakeFp16Op<CPUContext>::ReducedAdd(const std::vector<float>& vec) {
+float LayerNormUtils::ReducedAdd(const std::vector<float>& vec) {
constexpr int VEC_SIZE = 32;
std::vector<float> v(vec.begin(), vec.end());
@@ -69,8 +67,7 @@
return v[0];
}
-template <>
-void LayerNormFakeFp16Op<CPUContext>::calcMeanStd(
+void LayerNormUtils::calcMeanStd(
const int M,
const int N,
const float eps,
@@ -191,7 +188,14 @@
FLAGS_caffe2_fbgemm_fake_fp16_clamp);
}
-REGISTER_CPU_OPERATOR(LayerNormFakeFP16NNPI, LayerNormFakeFp16Op<CPUContext>);
+REGISTER_CPU_OPERATOR(LayerNormFakeFP16NNPI, LayerNormFakeFp16Op<false>);
OPERATOR_SCHEMA(LayerNormFakeFP16NNPI).NumInputs({1, 3}).NumOutputs(3);
+REGISTER_CPU_OPERATOR(LayerNormInt8QuantizeFakeNNPI,
+ LayerNormFakeFp16Op<true>);
+OPERATOR_SCHEMA(LayerNormInt8QuantizeFakeNNPI)
+ .IdenticalTypeAndShape()
+ .NumInputs({1, 3})
+ .NumOutputs(3);
+
} // namespace caffe2
diff --git a/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.h b/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.h
index afa3461..4aef846 100644
--- a/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.h
+++ b/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.h
@@ -9,8 +9,6 @@
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
-//#include "caffe2/fb/fbgemm/fbgemm_fp16/include/fbgemm/FbgemmFloat16.h"
-//#include <fbgemm/FbgemmFloat16.h>
#include <fbgemm/FbgemmConvert.h>
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"
@@ -20,11 +18,33 @@
namespace caffe2 {
-template <class Context>
-class LayerNormFakeFp16Op final : public Operator<Context> {
- public:
- USE_OPERATOR_CONTEXT_FUNCTIONS;
+class LayerNormUtils {
+ public:
+ static void calcY(
+ const int M,
+ const int N,
+ const float* X,
+ const float* mean,
+ const float* std,
+ const float* gamma,
+ const float* beta,
+ float* Y);
+
+ static void calcMeanStd(
+ const int M,
+ const int N,
+ const float eps,
+ const float* X,
+ float* mean,
+ float* std);
+
+ static float ReducedAdd(const std::vector<float>& vec);
+};
+
+template <bool quantizeOutput=false>
+class LayerNormFakeFp16Op final : public Operator<CPUContext> {
+ public:
template <class... Args>
explicit LayerNormFakeFp16Op(Args&&... args)
: Operator<CPUContext>(std::forward<Args>(args)...),
@@ -39,7 +59,14 @@
bool DoRunWithType() {
const auto& X = Input(INPUT);
- auto* Y = Output(OUTPUT, X.sizes(), at::dtype<float>());
+ vector <float> Y_fp16;
+
+ Tensor *Y;
+ if (!quantizeOutput) {
+ Y = Output(OUTPUT, X.sizes(), at::dtype<float>());
+ } else {
+ Y_fp16.resize(X.numel());
+ }
CAFFE_ENFORCE_GE(X.dim(), 2, "LayerNorm requires input dim >=2.");
const int canonical_axis = X.canonical_axis_index(axis_);
std::vector<int64_t> moments_dims(
@@ -49,9 +76,20 @@
auto* sigma = Output(STD, moments_dims, at::dtype<float>());
const int M = X.size_to_dim(canonical_axis);
const int N = X.size_from_dim(canonical_axis);
- Y->ResizeLike(X);
+
+ if (!quantizeOutput) {
+ Y->ResizeLike(X);
+ }
+
const float* X_data = X.template data<float>();
- float* Y_data = Y->template mutable_data<float>();
+
+ float *Y_data;
+ if (!quantizeOutput) {
+ Y_data = Y->template mutable_data<float>();
+ } else {
+ Y_data = Y_fp16.data();
+ }
+
float* mean_data = mean->template mutable_data<float>();
float* sigma_data = sigma->template mutable_data<float>();
@@ -65,7 +103,7 @@
X_data = X_rounded.data();
// Mean and Standard Deviation computation for the input data
- calcMeanStd(M, N, epsilon_, X_data, mean_data, sigma_data);
+ LayerNormUtils::calcMeanStd(M, N, epsilon_, X_data, mean_data, sigma_data);
const float* gamma_data = nullptr;
const float* beta_data = nullptr;
@@ -99,7 +137,7 @@
}
// Layer Normalized Output computation
- calcY(
+ LayerNormUtils::calcY(
M, N, X_data, mean_data, sigma_data, gamma_data, beta_data, Y_data);
if (InputSize() == 3 && !elementwise_affine_) {
@@ -128,30 +166,62 @@
}
}
+ // Quantize
+ // We should be using the same quantization function as Int8Quantize,
+ // but we need to adjust for the int8 vs. uint8 bias. A simple shift of the
+ // output is not enough because it causes problems when rounding inside the fma.
+ // TODO: figure out how to unify this with Int8Quantize
+ if (quantizeOutput) {
+ auto* Y_int8 = Outputs()[0]->template GetMutable<int8::Int8TensorCPU>();
+ Y_int8->t.ResizeLike(X);
+
+ int32_t Y_offset =
+ this->template GetSingleArgument<int>("Y_zero_point", 0);
+ auto Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
+
+ float inv_scale = 1.0f / Y_scale;
+ fbgemm::RoundToFloat16(
+ &inv_scale, &inv_scale, 1, false /* no clamping */);
+
+ Y_int8->scale = Y_scale;
+ Y_int8->zero_point = Y_offset;
+
+ int Nout = X.numel();
+
+ std::vector<float> inv_scalev(Nout, inv_scale);
+ std::vector<float> offsetv(Nout, Y_offset - 128.0);
+ uint8_t* Y_uint8_data = Y_int8->t.template mutable_data<uint8_t>();
+
+ fake_fp16::fma_fp16(Nout, Y_fp16.data(), inv_scalev.data(), offsetv.data());
+
+ const int32_t qmin = std::numeric_limits<uint8_t>::min();
+ const int32_t qmax = std::numeric_limits<uint8_t>::max();
+
+ for (int i = 0; i < Nout; i++) {
+ float halfRes = offsetv[i];
+ halfRes = round(halfRes);
+ halfRes = halfRes + 128.0;
+ if (std::isinf(halfRes)) {
+ if (halfRes > 0) {
+ halfRes = qmax;
+ } else {
+ halfRes = qmin;
+ }
+ }
+ if (halfRes > qmax) {
+ halfRes = qmax;
+ }
+ if (halfRes < qmin) {
+ halfRes = qmin;
+ }
+ Y_uint8_data[i] = static_cast<uint8_t>(halfRes);
+ }
+ }
+
return true;
}
private:
- void calcY(
- const int M,
- const int N,
- const float* X,
- const float* mean,
- const float* std,
- const float* gamma,
- const float* beta,
- float* Y);
-
- void calcMeanStd(
- const int M,
- const int N,
- const float eps,
- const float* X,
- float* mean,
- float* std);
-
- float ReducedAdd(const std::vector<float>& vec);
-
const int axis_;
const float epsilon_;
const bool elementwise_affine_;
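
To make the quantizeOutput branch above easier to follow, here is a NumPy sketch of the same arithmetic. It assumes fake_fp16::fma_fp16(N, a, b, out) computes out[i] = a[i] * b[i] + out[i] with fp16 rounding; that semantics is an assumption of this sketch, and the -128/+128 shift matches the int8-vs-uint8 bias adjustment noted in the operator comment:

import numpy as np

def quantize_like_fused_op(y, y_scale, y_zero_point):
    # inv_scale is rounded to fp16 before use, as in the operator
    inv_scale = np.float32(np.float16(1.0 / y_scale))
    # assumed fma_fp16 semantics: acc = y * inv_scale + acc, rounded to fp16
    acc = np.full(y.shape, y_zero_point - 128.0, dtype=np.float32)
    acc = (y * inv_scale + acc).astype(np.float16).astype(np.float32)
    q = np.round(acc) + 128.0  # undo the int8 -> uint8 shift
    q = np.where(np.isinf(q), np.where(q > 0, 255.0, 0.0), q)  # saturate infinities
    return np.clip(q, 0, 255).astype(np.uint8)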
diff --git a/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py b/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py
index ec99ca4..698b839 100644
--- a/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py
+++ b/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py
@@ -120,3 +120,125 @@
}
)
assert(0)
+
+ def _get_scale_zp(self, tensor):
+ tensor_max = np.max(tensor)
+ tensor_min = min(0, np.min(tensor))
+ scale = np.float32(np.float16((tensor_max - tensor_min) / 255.0))
+ if scale < 1e-6:
+ scale = 1e-6
+ zero_point = 0 - tensor_min / scale
+ zero_point = int(round(np.clip(zero_point, 0, 255.0)))
+ return (scale, zero_point)
+
+ def _layernorm_transform(self, X):
+ mean = np.mean(X, axis=1)
+ mean_exp = np.outer(mean, np.ones(X.shape[1]))
+ std = np.std(X, axis=1)
+ std_exp = np.outer(std, np.ones(X.shape[1]))
+ Y = (X - mean_exp) / std_exp
+ return Y
+
+ @given(seed=st.integers(0, 65535),
+ batch_size=st.integers(min_value=1, max_value=50),
+ size=st.integers(min_value=2, max_value=128),
+ epsilon=st.floats(min_value=1e-4, max_value=1e-3),
+ elementwise_affine=st.booleans())
+ @settings(deadline=None)
+ # re-enable when T74553975 gets fixed
+ def Skip_test_fused_ln_quantize(self, seed, batch_size, size, epsilon, elementwise_affine):
+ np.random.seed(seed)
+
+ # Reset the workspace
+ workspace.ResetWorkspace()
+ axis = 1
+
+ dims = np.array(([batch_size, size]))
+ X = np.random.uniform(size=dims).astype(np.float32) - 0.5
+ gamma = np.random.randn(*X.shape[axis:]).astype(np.float32)
+ beta = np.random.randn(*X.shape[axis:]).astype(np.float32)
+
+ Y = self._layernorm_transform(X)
+ scale, zp = self._get_scale_zp(Y)
+
+ pred_net = caffe2_pb2.NetDef()
+ pred_net.name = "pred"
+ pred_net.external_input.extend(["X", "gamma", "beta"])
+ pred_net.external_output.extend(["Y_q"])
+ pred_net.op.add().CopyFrom(
+ core.CreateOperator(
+ "LayerNorm",
+ ["X", "gamma", "beta"] if elementwise_affine else ["X"],
+ ["Y", "mean", "rstd"],
+ axis=axis,
+ epsilon=epsilon,
+ elementwise_affine=elementwise_affine
+ )
+ )
+ pred_net.op.add().CopyFrom(
+ core.CreateOperator(
+ "Int8Quantize", ["Y"], ["Y_q"], Y_scale=scale, Y_zero_point=zp
+ )
+ )
+
+ print(pred_net)
+ pred_net_ref = caffe2_pb2.NetDef()
+ pred_net_ref.name = "pred_ref"
+ pred_net_ref.external_input.extend(["X", "gamma", "beta"])
+ pred_net_ref.external_output.extend(["Y_q"])
+ pred_net_ref.op.add().CopyFrom(
+ core.CreateOperator(
+ "LayerNormInt8QuantizeFakeNNPI",
+ ["X", "gamma", "beta"] if elementwise_affine else ["X"],
+ ["Y_q", "mean", "rstd"],
+ axis=axis,
+ epsilon=epsilon,
+ elementwise_affine=elementwise_affine,
+ Y_scale=scale, Y_zero_point=zp
+ )
+ )
+ shape_hits = {"X": X.shape, "gamma": gamma.shape, "beta": beta.shape}
+ pred_net_onnxified = onnxifi_caffe2_net(
+ pred_net,
+ shape_hits,
+ debug=True,
+ adjust_batch=True,
+ use_onnx=False
+ )
+ num_onnxified_ops = sum(
+ 1 if o.type == "Onnxifi" else 0 for o in pred_net_onnxified.op)
+ np.testing.assert_equal(num_onnxified_ops, 1)
+
+ workspace.FeedBlob("X", X)
+ workspace.FeedBlob("gamma", gamma)
+ workspace.FeedBlob("beta", beta)
+
+ workspace.CreateNet(pred_net_ref)
+ workspace.CreateNet(pred_net_onnxified)
+
+ workspace.RunNet(pred_net_ref.name)
+ Y_c2 = workspace.FetchInt8Blob("Y_q")
+
+ workspace.RunNet(pred_net_onnxified.name)
+ Y_glow = workspace.FetchInt8Blob("Y_q")
+
+ if not np.allclose(Y_glow.data, Y_c2.data) or \
+ Y_glow.scale != Y_c2.scale or Y_glow.zero_point != Y_c2.zero_point:
+ diff_Y = np.abs(Y_glow.data.astype(np.float32) - Y_c2.data.astype(np.float32))
+ print_test_debug_info(
+ "layernorm",
+ {
+ "seed": seed,
+ "size": size,
+ "batch_size": batch_size,
+ "epsilon": epsilon,
+ "gamma": gamma,
+ "beta": beta,
+ "elementwise_affine": elementwise_affine,
+ "X": X,
+ "Y_glow": Y_glow,
+ "Y_c2": Y_c2,
+ "diff_Y": diff_Y,
+ }
+ )
+ assert(0)
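
For reference, a minimal standalone sketch of driving the fused operator directly through the Caffe2 Python API, mirroring pred_net_ref from the test above (the shapes and quantization arguments are illustrative):

import numpy as np
from caffe2.python import core, workspace

X = np.random.rand(4, 8).astype(np.float32) - 0.5
workspace.FeedBlob("X", X)

op = core.CreateOperator(
    "LayerNormInt8QuantizeFakeNNPI",
    ["X"],
    ["Y_q", "mean", "rstd"],
    axis=1,
    epsilon=1e-4,
    elementwise_affine=False,
    Y_scale=0.02,
    Y_zero_point=128,
)
workspace.RunOperatorOnce(op)
Y_q = workspace.FetchInt8Blob("Y_q")  # quantized data plus its scale / zero_point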
diff --git a/caffe2/opt/fakefp16_transform.cc b/caffe2/opt/fakefp16_transform.cc
index e9381fd..424056b 100644
--- a/caffe2/opt/fakefp16_transform.cc
+++ b/caffe2/opt/fakefp16_transform.cc
@@ -92,8 +92,8 @@
continue;
}
- const std::string& lm_output = op.output(0);
- auto next_ops = findMutableOperatorByInput(net, lm_output);
+ const std::string& ln_output = op.output(0);
+ auto next_ops = findMutableOperatorByInput(net, ln_output);
if (next_ops.size() != 1 || next_ops[0]->type() != "MulFakeFp16") {
LOG(INFO) << "next op isn't MulFakeFp16, skipping";
@@ -124,6 +124,51 @@
}
}
+void fakeFp16FoldLayerNormQuant(NetDef* net) {
+ for (auto& op : *net->mutable_op()) {
+ if (op.type() == "LayerNormFakeFP16NNPI") {
+ auto layernormNetPos = ArgumentHelper::GetSingleArgument<OperatorDef, int>(
+ op, "net_pos", -1);
+ LOG(INFO) << "Attemping to fuse LayerNormFakeFP16NNPI w Quant at "
+ << layernormNetPos;
+ if (op.input().size() != 1) {
+ LOG(INFO) << "input isn't 1, is " << op.input().size() << " skipping";
+ continue;
+ }
+
+ const std::string& ln_output = op.output(0);
+ auto next_ops = findMutableOperatorByInput(net, ln_output);
+
+ if (next_ops.size() != 1 || next_ops[0]->type() != "Int8QuantizeNNPI") {
+ LOG(INFO) << "next op isn't Int8QuantizeNNPI, skipping";
+ continue;
+ }
+
+ auto* quantOp = next_ops[0];
+
+ if (quantOp->output().size() != 1) {
+ LOG(INFO) << "more than one output for quant, skipping";
+ continue;
+ }
+
+ op.set_type("LayerNormInt8QuantizeFakeNNPI");
+
+ *op.mutable_output(0) = quantOp->output(0);
+ op.add_arg()->CopyFrom(MakeArgument("Y_scale",
+ ArgumentHelper::GetSingleArgument<OperatorDef, float>(*quantOp, "Y_scale", -1)));
+ op.add_arg()->CopyFrom(MakeArgument("Y_zero_point",
+ ArgumentHelper::GetSingleArgument<OperatorDef, int>(*quantOp, "Y_zero_point", -1)));
+
+ auto quantNetPos = ArgumentHelper::GetSingleArgument<OperatorDef, int>(
+ *quantOp, "net_pos", -1);
+
+ quantOp->set_type("delete_me_optimized_away");
+
+ LOG(INFO) << "Fused LayerNormFakeFP16NNPI w Quant at " << layernormNetPos << " " << quantNetPos;
+ }
+ }
+}
+
void fakeFp16FoldSwish(NetDef* net) {
// find a sequence deq->swish->quant and replace it
for (auto& op : *net->mutable_op()) {
@@ -236,6 +281,7 @@
fakeFp16FoldLayerNorm(net);
fakeFp16FoldSwish(net);
fakeFp16FoldTanhQuant(net);
+ fakeFp16FoldLayerNormQuant(net);
auto iter = net->mutable_op()->begin();
while (iter != net->mutable_op()->end()) {
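
Schematically, fakeFp16FoldLayerNormQuant rewrites the operator sequence as follows (arguments other than the quantization parameters elided):

# before:
#   LayerNormFakeFP16NNPI (X) -> Y, mean, rstd
#   Int8QuantizeNNPI      (Y) -> Y_q                      [Y_scale, Y_zero_point]
#
# after:
#   LayerNormInt8QuantizeFakeNNPI (X) -> Y_q, mean, rstd  [Y_scale, Y_zero_point copied over]
#   (the Int8QuantizeNNPI op is retyped to "delete_me_optimized_away"
#    and removed by the cleanup loop at the end of the transform)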
diff --git a/caffe2/opt/onnxifi_transformer.cc b/caffe2/opt/onnxifi_transformer.cc
index 9bceb56..9166153 100644
--- a/caffe2/opt/onnxifi_transformer.cc
+++ b/caffe2/opt/onnxifi_transformer.cc
@@ -1102,7 +1102,7 @@
int pos =
ArgumentHelper::GetSingleArgument<OperatorDef, int>(op, kNetPos, -1);
if (blocklisted_ops.count(pos)) {
- LOG(INFO) << "Skipping blocklisted op " << op.type() << " at pos " << pos;
+ LOG(INFO) << "Skipping blacklisted op " << op.type() << " at pos " << pos;
return false;
}
const OpSchema* schema = OpSchemaRegistry::Schema(op.type());