rename the LayerNorm operator and add it to the replacement map (#40318)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40318
Rename the LayerNorm fakefp16 operator to follow the correct naming convention.
Add it to the map of replacement ops.
This can be done even though the operator implementation is not yet complete, because the op is blacklisted anyway.
Test Plan: Ran net_runner and inspected the log to confirm that the replacement happened.
Reviewed By: venkatacrc
Differential Revision: D22145900
fbshipit-source-id: f19794ec05234b877f7697ed8b05dd8f46606c47
diff --git a/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.cc b/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.cc
index 6eb5873..cfc3685 100644
--- a/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.cc
+++ b/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.cc
@@ -136,7 +136,7 @@
}
}
-REGISTER_CPU_OPERATOR(LayerNormFakeFP16, LayerNormFakeFp16Op<CPUContext>);
-OPERATOR_SCHEMA(LayerNormFakeFP16).NumInputs({1, 3}).NumOutputs(3);
+REGISTER_CPU_OPERATOR(LayerNormFakeFP16NNPI, LayerNormFakeFp16Op<CPUContext>);
+OPERATOR_SCHEMA(LayerNormFakeFP16NNPI).NumInputs({1, 3}).NumOutputs(3);
} // namespace caffe2
diff --git a/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py b/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py
index 698d7cd..920b3ff 100644
--- a/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py
+++ b/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py
@@ -63,7 +63,7 @@
pred_net_ref.external_output.extend(["Y", "mean", "rstd"])
pred_net_ref.op.add().CopyFrom(
core.CreateOperator(
- "LayerNormFakeFP16",
+ "LayerNormFakeFP16NNPI",
["X", "gamma", "beta"],
["Y", "mean", "rstd"],
axis=1,
diff --git a/caffe2/opt/custom/fakefp16_transform.cc b/caffe2/opt/custom/fakefp16_transform.cc
index 1427042..efd9c70 100644
--- a/caffe2/opt/custom/fakefp16_transform.cc
+++ b/caffe2/opt/custom/fakefp16_transform.cc
@@ -25,6 +25,7 @@
{"Int8FC", "Int8FCFakeAcc32NNPI"},
{"Int8Quantize", "Int8QuantizeNNPI"},
{"Int8Dequantize", "Int8DequantizeNNPI"},
+ {"LayerNorm", "LayerNormFakeFP16NNPI"},
{"FbFCPacked", "Fp16FCAcc32NNPI"},
{"SparseLengthsSum", "SparseLengthsSumFakeFP16AccFP16"},
{"SparseLengthsWeightedSum", "SparseLengthsWeightedSumFakeFP16AccFP16"},