Optimize mobile model on cloned module instead of in-place transformation (#36621)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36621
Instead of transforming the module in place inside the optimizeForMobile method,
we want to preserve the original structure of the passed scripted module.
Before optimization starts, the module is cloned; the subsequent optimization
passes run on the clone, and the optimized clone is returned.
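
A minimal caller-side sketch of the new usage (the model path, the save step,
and the main() wrapper are illustrative assumptions, not part of this change):

  #include <torch/script.h>
  #include <torch/csrc/jit/passes/xnnpack_rewrite.h>

  int main() {
    // Load a scripted module; "model.pt" is a placeholder path.
    torch::jit::script::Module m = torch::jit::load("model.pt");

    // optimizeForMobile no longer mutates `m`; it clones the module, runs the
    // optimization passes on the clone, and returns the optimized copy
    // (or c10::nullopt when XNNPACK is not enabled).
    c10::optional<torch::jit::script::Module> optimized =
        torch::jit::optimizeForMobile(m);

    if (optimized.has_value()) {
      optimized->save("model_optimized.pt");  // placeholder output path
    }
    return 0;
  }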
Test Plan:
unit test
python test/test_mobile_optimizer.py
Imported from OSS
Differential Revision: D21028406
fbshipit-source-id: 756172ef99b1c1df6bb7d00e5deca85a4c239a87
diff --git a/torch/csrc/jit/passes/xnnpack_rewrite.cpp b/torch/csrc/jit/passes/xnnpack_rewrite.cpp
index 2d10870..289a7a4 100644
--- a/torch/csrc/jit/passes/xnnpack_rewrite.cpp
+++ b/torch/csrc/jit/passes/xnnpack_rewrite.cpp
@@ -296,12 +296,14 @@
PrePackingOpsFolder(m, filter_fn, "prepack_folding");
}
-void optimizeForMobile(script::Module& m) {
- m.eval();
- m = FoldConvBatchNorm2d(m);
- insertPrePackedOps(m);
- m = freeze_module(m);
- FoldPrePackingOps(m);
+c10::optional<script::Module> optimizeForMobile(const script::Module& m) {
+ auto cloned_module = m.clone();
+ cloned_module.eval();
+ cloned_module = FoldConvBatchNorm2d(cloned_module);
+ insertPrePackedOps(cloned_module);
+ cloned_module = freeze_module(cloned_module);
+ FoldPrePackingOps(cloned_module);
+ return cloned_module;
}
#else
@@ -326,10 +328,11 @@
"XNNPACK is not enabled. Please build with USE_XNNPACK=1");
}
-void optimizeForMobile(script::Module& m) {
+c10::optional<script::Module> optimizeForMobile(const script::Module& m) {
TORCH_INTERNAL_ASSERT(
"Mobile optimizaiton only available with XNNPACK at the moment. "
"XNNPACK is not enabled. Please build with USE_XNNPACK=1");
+ return c10::nullopt;
}
#endif
diff --git a/torch/csrc/jit/passes/xnnpack_rewrite.h b/torch/csrc/jit/passes/xnnpack_rewrite.h
index e69fff1..7f0b579 100644
--- a/torch/csrc/jit/passes/xnnpack_rewrite.h
+++ b/torch/csrc/jit/passes/xnnpack_rewrite.h
@@ -9,6 +9,7 @@
TORCH_API void insertPrePackedOps(script::Module& module);
TORCH_API void fusePrePackedLinearConvWithClamp(script::Module& module);
TORCH_API void FoldPrePackingOps(script::Module& module);
-TORCH_API void optimizeForMobile(script::Module& module);
+TORCH_API c10::optional<script::Module> optimizeForMobile(
+ const script::Module& module);
} // namespace jit
} // namespace torch