[GPU] Enable optimize_for_metal in fbcode (#47102)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47102
Since the current mobile end-to-end workflow involves calling `optimize_for_mobile` in Python, the goal here is to be able to use `optimize_for_mobile(m, backend="metal")` in fbcode.
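For context, a minimal sketch of the intended Python call site (a sketch only: the `backend="metal"` keyword is what this stack enables, and the model/paths below are illustrative):

```python
# Sketch only: assumes optimize_for_mobile accepts backend="metal",
# which is what this stack is wiring up (export path lands in the next diff).
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

class Model(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

scripted = torch.jit.script(Model())
# The metal rewrite passes prepack conv weights and insert metal::copy_to_host
# on tensor outputs (see metal_rewrite.cpp below).
metal_model = optimize_for_mobile(scripted, backend="metal")
torch.jit.save(metal_model, "model_metal.pt")
```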
ghstack-source-id: 115749752
Test Plan:
1. Be able to export models for metal (see the next diff)
2. Make sure the change won't break the OSS workflow
3. Make sure the change won't break the mobile build.
Reviewed By: xcheng16
Differential Revision: D24644422
fbshipit-source-id: bd77e22f0799533a96d048207932055fd051a67e
diff --git a/aten/src/ATen/native/metal/MetalPrepackOpRegister.cpp b/aten/src/ATen/native/metal/MetalPrepackOpRegister.cpp
index f73b18b..89f08e2 100644
--- a/aten/src/ATen/native/metal/MetalPrepackOpRegister.cpp
+++ b/aten/src/ATen/native/metal/MetalPrepackOpRegister.cpp
@@ -1,6 +1,7 @@
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/native/metal/MetalPrepackOpContext.h>
-#include <torch/script.h>
+#include <ATen/ATen.h>
+
#if defined(C10_IOS)
#import <ATen/native/metal/mpscnn/MPSCNNOps.h>
diff --git a/torch/csrc/jit/passes/metal_rewrite.cpp b/torch/csrc/jit/passes/metal_rewrite.cpp
index 21f9f32..ffc0000 100644
--- a/torch/csrc/jit/passes/metal_rewrite.cpp
+++ b/torch/csrc/jit/passes/metal_rewrite.cpp
@@ -15,8 +15,6 @@
namespace torch {
namespace jit {
-#ifdef USE_PYTORCH_METAL
-
namespace {
void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
@@ -160,16 +158,14 @@
void metalInsertCopyOps(script::Module& module) {
auto graph = module.get_method("forward").graph();
auto&& outputs = graph->outputs();
- for (int i = 0; i < outputs.size(); ++i) {
+ for (size_t i = 0; i < outputs.size(); ++i) {
Value* output = outputs[i];
- std::cout << "find output: " << *output->node() << std::endl;
auto namedValue = NamedValue("", output);
if (namedValue.type()->kind() == TypeKind::TensorType) {
// find the insertion point
WithInsertPoint ip(output->node()->next());
Value* replaced_output = graph->insert(
Symbol::fromQualString("metal::copy_to_host"), {namedValue});
- std::cout << "insert: " << *replaced_output->node() << std::endl;
// replaced the output
graph->block()->replaceOutput(i, replaced_output);
}
@@ -195,37 +191,5 @@
return cloned_module;
}
-#else
-
-void metalInsertPrePackedOps(std::shared_ptr<Graph>& graph) {
- TORCH_INTERNAL_ASSERT(
- "metal is not enabled. Please build with USE_PYTORCH_METAL=1");
-}
-
-void metalInsertPrePackedOps(script::Module& module) {
- TORCH_INTERNAL_ASSERT(
- "metal is not enabled. Please build with USE_PYTORCH_METAL=1");
-}
-
-TORCH_API void metalFusePrePackedConvWithClamp(script::Module& module) {
- TORCH_INTERNAL_ASSERT(
- "metal is not enabled. Please build with USE_PYTORCH_METAL=1");
-}
-
-TORCH_API void metalFoldPrePackingOps(script::Module& module) {
- TORCH_INTERNAL_ASSERT(
- "metal is not enabled. Please build with USE_PYTORCH_METAL=1");
-}
-
-script::Module metalOptimizeForMobile(
- const script::Module& m,
- const std::vector<std::string>& preserved_methods) {
- TORCH_INTERNAL_ASSERT(
- "Mobile optimizaiton only available with metal at the moment. "
- "metal is not enabled. Please build with USE_PYTORCH_METAL=1");
- return m;
-}
-
-#endif
} // namespace jit
} // namespace torch