Fallback to CPU when remote end does not have CUDA for profiling (#44967)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44967
When enabling the profiler on the server, the remote machine may not
have CUDA available even though the caller does. Previously this would
crash; now we fall back to CPU profiling and log a warning.
ghstack-source-id: 112977906
Test Plan: CI
Reviewed By: pritamdamania87
Differential Revision: D23790729
fbshipit-source-id: dc6eba172b7e666842d54553f52a6b9d5f0a5362
diff --git a/torch/csrc/distributed/rpc/request_callback_impl.cpp b/torch/csrc/distributed/rpc/request_callback_impl.cpp
index b68cb40..c429fde 100644
--- a/torch/csrc/distributed/rpc/request_callback_impl.cpp
+++ b/torch/csrc/distributed/rpc/request_callback_impl.cpp
@@ -502,6 +502,14 @@
}
}
+bool RequestCallbackImpl::cudaAvailable() const {
+ #ifdef USE_CUDA
+ return true;
+ #else
+ return false;
+ #endif
+}
+
} // namespace rpc
} // namespace distributed
} // namespace torch
diff --git a/torch/csrc/distributed/rpc/request_callback_impl.h b/torch/csrc/distributed/rpc/request_callback_impl.h
index 0591cc8..836e496 100644
--- a/torch/csrc/distributed/rpc/request_callback_impl.h
+++ b/torch/csrc/distributed/rpc/request_callback_impl.h
@@ -54,6 +54,8 @@
const MessageType& messageType,
const int64_t messageId,
const std::shared_ptr<FutureMessage>& responseFuture) const override;
+
+ bool cudaAvailable() const override;
};
} // namespace rpc
diff --git a/torch/csrc/distributed/rpc/request_callback_no_python.cpp b/torch/csrc/distributed/rpc/request_callback_no_python.cpp
index f5df65e..d41c8f2 100644
--- a/torch/csrc/distributed/rpc/request_callback_no_python.cpp
+++ b/torch/csrc/distributed/rpc/request_callback_no_python.cpp
@@ -482,7 +482,26 @@
case MessageType::RUN_WITH_PROFILING_REQ: {
auto& rpcWithProfilingReq = static_cast<RpcWithProfilingReq&>(rpc);
auto wrappedMsgType = rpcWithProfilingReq.wrappedMessageType();
- const auto profilingConfig = rpcWithProfilingReq.getProfilingConfig();
+ auto profilingConfig = rpcWithProfilingReq.getProfilingConfig();
+ // If requested with CUDA from caller but CUDA is not available on this
+ // machine, fallback to CPU and log a warning instead of crashing.
+ if (profilingConfig.state ==
+ torch::autograd::profiler::ProfilerState::CUDA &&
+ !this->cudaAvailable()) {
+ profilingConfig = torch::autograd::profiler::ProfilerConfig(
+ torch::autograd::profiler::ProfilerState::CPU,
+ profilingConfig.report_input_shapes,
+ profilingConfig.profile_memory);
+
+ LOG(WARNING)
+ << "Profiler was requested to be enabled with CUDA on this node, but CUDA is not available. "
+ << "Falling back to CPU profiling only.";
+ }
+ TORCH_INTERNAL_ASSERT(
+ profilingConfig.state !=
+ torch::autograd::profiler::ProfilerState::CUDA ||
+ this->cudaAvailable(),
+ "Profiler state set to CUDA but CUDA not available.");
const auto profilingKeyId = rpcWithProfilingReq.getProfilingId();
auto wrappedRpcResponseFuture = std::make_shared<FutureMessage>();
// Enable the profiler with the config from the sender.
@@ -571,6 +590,14 @@
return createExceptionResponse(errorMsg, messageId);
}
+bool RequestCallbackNoPython::cudaAvailable() const {
+ #ifdef USE_CUDA
+ return true;
+ #else
+ return false;
+ #endif
+}
+
} // namespace rpc
} // namespace distributed
} // namespace torch
diff --git a/torch/csrc/distributed/rpc/request_callback_no_python.h b/torch/csrc/distributed/rpc/request_callback_no_python.h
index dd54ea0..b54fe17 100644
--- a/torch/csrc/distributed/rpc/request_callback_no_python.h
+++ b/torch/csrc/distributed/rpc/request_callback_no_python.h
@@ -84,6 +84,8 @@
const std::exception& e,
const MessageType messageType,
int64_t messageId) const;
+
+ virtual bool cudaAvailable() const;
};
} // namespace rpc