add net transforms for fusion (#42763)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42763
add the fp16 fusions as net transforms:
-layernorm fused with mul+add
-swish int8
Test Plan: added unit test, ran flows
Reviewed By: yinghai
Differential Revision: D23002043
fbshipit-source-id: f0b13d51d68c240b05d2a237a7fb8273e996328b
diff --git a/caffe2/python/fakefp16_transform_lib.py b/caffe2/python/fakefp16_transform_lib.py
new file mode 100644
index 0000000..a550b4d
--- /dev/null
+++ b/caffe2/python/fakefp16_transform_lib.py
@@ -0,0 +1,17 @@
+#!/usr/bin/env python3
+from __future__ import division
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import caffe2.python._import_c_extension as C
+from caffe2.proto.caffe2_pb2 import NetDef
+
+def fakeFp16FuseOps(net : NetDef) -> NetDef:
+ net_str = net.SerializeToString()
+ print(dir(C))
+ out_str = C.fakeFp16FuseOps(net_str)
+ out_net = NetDef()
+ out_net.ParseFromString(out_str)
+
+ return out_net
diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc
index 767e0af..b275ccd 100644
--- a/caffe2/python/pybind_state.cc
+++ b/caffe2/python/pybind_state.cc
@@ -29,6 +29,7 @@
#include "caffe2/opt/optimize_ideep.h"
#include "caffe2/opt/passes.h"
#include "caffe2/opt/shape_info.h"
+#include "caffe2/opt/custom/fakefp16_transform.h"
#include "caffe2/predictor/emulator/data_filler.h"
#include "caffe2/predictor/predictor.h"
#include "caffe2/python/pybind_state_registry.h"
@@ -1878,6 +1879,18 @@
return py::bytes(out);
});
+  m.def("fakeFp16FuseOps", [](const py::bytes& net_str) {
+    // Parse the serialized NetDef, run the fusion pass in place, re-serialize.
+    caffe2::NetDef netDef;
+    CAFFE_ENFORCE(
+        ParseProtoFromLargeString(net_str.cast<std::string>(), &netDef),
+        "broken pred_net protobuf");
+    opt::fakeFp16FuseOps(&netDef);
+    std::string out_net;
+    CAFFE_ENFORCE(netDef.SerializeToString(&out_net), "serialization failed");
+    return py::bytes(out_net);
+  });
+
auto initialize = [&]() {
// Initialization of the module
#ifdef USE_NUMPY
diff --git a/caffe2/python/test/fakefp16_transform_test.py b/caffe2/python/test/fakefp16_transform_test.py
new file mode 100644
index 0000000..e970b1c
--- /dev/null
+++ b/caffe2/python/test/fakefp16_transform_test.py
@@ -0,0 +1,23 @@
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import unittest
+from caffe2.python.fakefp16_transform_lib import fakeFp16FuseOps
+from caffe2.python import core
+
+class Transformer(unittest.TestCase):
+ def test_fuse(self):
+ net_swish = core.Net("test_swish")
+ net_swish_init = core.Net("test_swish_init")
+
+ deq = core.CreateOperator("Int8DequantizeNNPI", ["Xq"], ["X"])
+ swish = core.CreateOperator("SwishFakeFp16NNPI", ["X"], ["Y"])
+ quant = core.CreateOperator("Int8QuantizeNNPI", ["Y"], ["Y_q"])
+ net_swish.Proto().op.extend(
+ [
+ deq, swish, quant
+ ]
+ )
+ out_net = fakeFp16FuseOps(net_swish.Proto())
+ assert(len(out_net.op) == 1)