[PyTorch][ET] Collect Process Groups Mapping Info (#104373)
Summary: Add the logic and interfaces to log the ProcessGroup comms configuration (unique IDs, backend type, and ranks info).
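For quick reference, a minimal sketch of how the new hooks can be queried from Python. It assumes a default NCCL process group has already been initialized; only `_id()`, `_backend_id()`, `ProcessGroup.BackendType`, and `_world.pg_config_info` come from this diff, the rest is illustrative:
```python
# Illustrative sketch only; assumes a default NCCL process group has already
# been initialized elsewhere via torch.distributed.init_process_group("nccl").
import torch.distributed as dist
from torch.distributed.distributed_c10d import _world

pg = dist.group.WORLD                       # default process group
print(pg._id())                             # opaque ID of the ProcessGroup object
print(pg._backend_id(dist.ProcessGroup.BackendType.NCCL))  # opaque ID of its NCCL backend

# Aggregated view that backs the "## process_group:init ##" record in the ET:
# one dict per process group with pg_id, backend_id, backend_config, and ranks.
print(_world.pg_config_info)
```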
Test Plan:
Testing in HPC:
```
TORCH_LOGS=all ../buck-out/v2/gen/fbcode/c8344b52091f4f7f/hpc/models/ads/__ads_10x_launcher__/ads_10x_launcher.par +launcher=local launcher.num_trainers=4 +data_loader=random data_loader.num_batches=2000
```
Example output in ET:
```
{
"name": "## process_group:init ##", "id": 3, "rf_id": 1, "parent": 2, "fw_parent": 0, "seq_id": -1, "scope": 7, "tid": 1, "fw_tid": 0, "op_schema": "",
"inputs": ["[{'pg_id': 140538064364672, 'backend_id': 140538060772480, 'backend_config': 'cuda:nccl', 'ranks': {0: 0, 1: 1, 2: 2, 3: 3}}, {'pg_id': 140538064363904, 'backend_id': 140538042628864, 'backend_config': 'cuda:nccl', 'ranks': {0: 0, 1: 1, 2: 2, 3: 3}}]"], "input_shapes": [[]], "input_types": ["String"],
"outputs": [], "output_shapes": [], "output_types": []
},
```
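The `inputs` payload above is a Python-literal string rather than strict JSON, so a post-processing script can recover the process group mapping roughly as sketched below. The node name is taken from the example output; the file name and the top-level `nodes` key are assumptions about the ET dump layout:
```python
# Sketch: recover the process-group mapping from a dumped ET file.
import ast
import json

with open("execution_trace.json") as f:
    et = json.load(f)

for node in et.get("nodes", []):
    if node.get("name") == "## process_group:init ##":
        # inputs[0] holds a Python-literal list of per-PG dicts, e.g.
        # [{'pg_id': ..., 'backend_id': ..., 'backend_config': 'cuda:nccl', 'ranks': {...}}]
        pg_configs = ast.literal_eval(node["inputs"][0])
        for cfg in pg_configs:
            print(cfg["pg_id"], cfg["backend_id"], cfg["backend_config"], cfg["ranks"])
```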
Differential Revision: D46321690
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104373
Approved by: https://github.com/kwen2501
diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi
index b3f9730..4c06486 100644
--- a/torch/_C/_autograd.pyi
+++ b/torch/_C/_autograd.pyi
@@ -91,7 +91,7 @@
def _add_metadata_json(key: str, value: str) -> None: ...
def _kineto_step() -> None: ...
def kineto_available() -> bool: ...
-def _record_function_with_args_enter(name: str, args: List[Any]) -> torch.Tensor: ...
+def _record_function_with_args_enter(name: str, *args) -> torch.Tensor: ...
def _record_function_with_args_exit(handle: torch.Tensor) -> None: ...
def _supported_activities() -> Set[ProfilerActivity]: ...
def _enable_record_function(enable: bool) -> None: ...
diff --git a/torch/csrc/distributed/c10d/Backend.hpp b/torch/csrc/distributed/c10d/Backend.hpp
index 9d43bd5..ea7d5e8 100644
--- a/torch/csrc/distributed/c10d/Backend.hpp
+++ b/torch/csrc/distributed/c10d/Backend.hpp
@@ -53,6 +53,12 @@
return size_;
}
+ // Returns a unique opaque ID of this backend that can be used to correlate
+ // with its collectives.
+ int64_t getID() const {
+ return reinterpret_cast<std::intptr_t>(this);
+ }
+
virtual void startCoalescing() {
TORCH_CHECK(
false,
diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp
index 63d2a1a..8567dc1 100644
--- a/torch/csrc/distributed/c10d/ProcessGroup.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp
@@ -93,6 +93,17 @@
return size_;
}
+ // Returns a unique opaque ID of this process group object.
+ int64_t getID() const {
+ return reinterpret_cast<std::intptr_t>(this);
+ }
+
+ // Returns a unique opaque ID of a backend for the specific backend type
+ // that can be correlated with this process group's collectives.
+ int64_t getBackendID(BackendType backend_type) const {
+ return reinterpret_cast<std::intptr_t>(getBackend(backend_type).get());
+ }
+
virtual const std::string getBackendName() const {
return options_->backend;
};
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index 2820b3c..19789c4 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -670,13 +670,14 @@
<< "\nTIMEOUT(ms): " << options_->timeout.count()
<< "\nUSE_HIGH_PRIORITY_STREAM: "
<< options_->is_high_priority_stream
- << "\n TORCH_DISTRIBUTED_DEBUG: "
+ << "\nTORCH_DISTRIBUTED_DEBUG: "
<< std::string(torch_distributed_debug)
- << "\n NCCL_DEBUG: " << std::string(nccl_debug);
+ << "\nNCCL_DEBUG: " << std::string(nccl_debug)
+ << "\nID=" << this->getID();
RECORD_PARAM_COMMS(
0, // seq
- reinterpret_cast<std::intptr_t>(this), // process group ptr
+ this->getID(),
rank, // rank
"init", // colName
0, // inSize
@@ -1963,7 +1964,7 @@
RECORD_PARAM_COMMS_DATA(
static_cast<int>(
this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
- reinterpret_cast<std::intptr_t>(this), // process group ptr
+ this->getID(),
tensors, // inputTensors
tensors, // outputTensors
rank_, // rank
@@ -1987,7 +1988,7 @@
RECORD_PARAM_COMMS_DATA(
static_cast<int>(
this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
- reinterpret_cast<std::intptr_t>(this), // process group ptr
+ this->getID(),
tensors, // inputTensors
tensors, // outputTensors
rank_, // rank
@@ -2014,7 +2015,7 @@
RECORD_PARAM_COMMS_DATA(
static_cast<int>(
this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
- reinterpret_cast<std::intptr_t>(this), // process group ptr
+ this->getID(),
tensors, // inputTensors
tensors, // outputTensors
rank_, // rank
@@ -2073,7 +2074,7 @@
static_cast<int>(
this->getSequenceNumberForGroup() +
1), // seq + 1 to match collective increment.
- reinterpret_cast<std::intptr_t>(this), // process group ptr
+ this->getID(),
inputTensors, // inputTensors
outputTensors, // outputTensors
rank_, // rank
@@ -2114,7 +2115,7 @@
RECORD_PARAM_COMMS_DATA(
static_cast<int>(
this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
- reinterpret_cast<std::intptr_t>(this),
+ this->getID(),
tensors, // inputTensors
tensors, // outputTensors
rank_, // rank
@@ -2177,7 +2178,7 @@
RECORD_PARAM_COMMS_DATA(
static_cast<int>(
this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
- reinterpret_cast<std::intptr_t>(this), // process group ptr
+ this->getID(),
inputTensors, // inputTensors
outputTensors, // outputTensors
rank_, // rank
@@ -2233,7 +2234,7 @@
static_cast<int>(
this->getSequenceNumberForGroup() +
1), // seq + 1 to match collective
- reinterpret_cast<std::intptr_t>(this), // process group ptr
+ this->getID(),
inputTensors, // inputTensors
outputTensors, // outputTensors
rank_, // rank
@@ -2374,7 +2375,7 @@
static_cast<int>(
this->getSequenceNumberForGroup() +
1), // seq + 1 to match collective
- reinterpret_cast<std::intptr_t>(this), // process group ptr
+ this->getID(),
inputTensors, // inputTensors
outputTensors, // outputTensors
rank_, // rank
@@ -2493,7 +2494,7 @@
RECORD_PARAM_COMMS_DATA(
static_cast<int>(
this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
- reinterpret_cast<std::intptr_t>(this), // process group ptr
+ this->getID(),
inputTensor, // inputTensor
outputTensor, // outputTensor
rank_, // rank
@@ -2572,7 +2573,7 @@
RECORD_PARAM_COMMS(
static_cast<int>(
this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
- reinterpret_cast<std::intptr_t>(this), // process group ptr
+ this->getID(),
rank_, // rank
"barrier", // colName
0, // inSize
@@ -2650,7 +2651,7 @@
static_cast<int>(
this->getSequenceNumberForGroup() +
1), // seq + 1 to match collective
- reinterpret_cast<std::intptr_t>(this), // process group ptr
+ this->getID(),
inputTensor, // inputTensor
outputTensor, // outputTensor
rank_, // rank
@@ -2691,7 +2692,7 @@
static_cast<int>(
this->getSequenceNumberForGroup() +
1), // seq + 1 to match collective
- reinterpret_cast<std::intptr_t>(this), // process group ptr
+ this->getID(),
inputTensor, // inputTensor
outputTensor, // outputTensor
rank_, // rank
@@ -2955,7 +2956,7 @@
RECORD_PARAM_COMMS_DATA(
static_cast<int>(
this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
- reinterpret_cast<std::intptr_t>(this), // process group ptr
+ this->getID(),
inputTensors, // inputTensors
outputTensors, // outputTensors
rank_, // rank
@@ -3041,7 +3042,7 @@
RECORD_PARAM_COMMS_DATA(
static_cast<int>(
this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
- reinterpret_cast<std::intptr_t>(this), // process group ptr
+ this->getID(),
inputTensors, // inputTensors
outputTensors, // outputTensors
rank_, // rank
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 7cde2e0..d9e4064 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -1375,6 +1375,11 @@
.def("rank", &::c10d::ProcessGroup::getRank)
.def("size", &::c10d::ProcessGroup::getSize)
.def("name", &::c10d::ProcessGroup::getBackendName)
+ .def("_id", &::c10d::ProcessGroup::getID)
+ .def(
+ "_backend_id",
+ &::c10d::ProcessGroup::getBackendID,
+ py::arg("backend_type"))
.def_property_readonly("options", &::c10d::ProcessGroup::getOptions)
.def(
"broadcast",
diff --git a/torch/csrc/profiler/standalone/execution_trace_observer.cpp b/torch/csrc/profiler/standalone/execution_trace_observer.cpp
index d2bc513..93d9e2d 100644
--- a/torch/csrc/profiler/standalone/execution_trace_observer.cpp
+++ b/torch/csrc/profiler/standalone/execution_trace_observer.cpp
@@ -44,6 +44,8 @@
return fmt::format("[{}]", fmt::join(v, ","));
}
+std::string json_str_escape(const std::string& str);
+
constexpr size_t maxNumElements = 4096;
inline std::string getValueType(
@@ -128,10 +130,11 @@
if (str_val.size() > maxNumElements) {
LOG(WARNING) << "string size=" << str_val.size()
<< " exceeded maxNumElements=" << maxNumElements;
- return fmt::format("\"{}\"", str_val.substr(0, maxNumElements));
+ return fmt::format(
+ "\"{}\"", json_str_escape(str_val.substr(0, maxNumElements)));
}
- return fmt::format("\"{}\"", str_val);
+ return fmt::format("\"{}\"", json_str_escape(str_val));
} else if (val.isDevice()) {
return fmt::format("\"{}\"", val.toDevice().str());
}
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 0f2cbc7..bc5d13d 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -32,7 +32,6 @@
get_debug_level,
Work
)
-from torch.autograd.profiler import record_function
from .constants import default_pg_timeout
from .c10d_logger import _exception_logger, _time_logger
from .rendezvous import register_rendezvous_handler, rendezvous # noqa: F401
@@ -181,6 +180,13 @@
MPI : ["cpu"],
}
+ backend_type_map: Dict[str, ProcessGroup.BackendType] = {
+ UNDEFINED: ProcessGroup.BackendType.UNDEFINED,
+ GLOO : ProcessGroup.BackendType.GLOO,
+ NCCL: ProcessGroup.BackendType.NCCL,
+ UCC: ProcessGroup.BackendType.UCC,
+ }
+
def __new__(cls, name: str):
if not isinstance(name, str):
raise ValueError(f"Backend name must be a string, but got: {name}")
@@ -228,6 +234,7 @@
setattr(Backend, name.upper(), name.lower())
Backend.backend_list.append(name.lower())
+ Backend.backend_type_map[name.lower()] = ProcessGroup.BackendType.CUSTOM
# Update device capability matrix in Backend class
if devices is None:
@@ -506,6 +513,27 @@
def pg_default_device(self) -> Dict[ProcessGroup, torch.device]:
return self._pg_default_device
+ @property
+ def pg_config_info(self) -> List[Dict[str, Union[int, str]]]:
+ """
+ Returns a list of dicts, one per process group, with the process group and
+ backend unique IDs and their configuration (backend type and ranks).
+ """
+ config_info = []
+ for pg, backend in self.pg_map.items():
+ # backend is a tuple with the first element being the backend type ("nccl", etc.)
+ backend_type = Backend.backend_type_map[backend[0]]
+ config_info.append(
+ {
+ "pg_id": pg._id(),
+ "backend_id": pg._backend_id(backend_type),
+ "backend_config": self.pg_backend_config[pg],
+ "ranks": self.pg_group_ranks[pg],
+ }
+ )
+ return config_info
+
+
_world = _World()
"""Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it"""
@@ -3888,18 +3916,17 @@
group_name = _process_group_name(ranks, use_hashed_name=use_local_synchronization)
- with record_function(f"## process_group:init with ranks: {ranks}"):
- pg, pg_store = _new_process_group_helper(
- group_world_size,
- group_rank,
- ranks,
- backend,
- default_store,
- group_name=group_name,
- pg_options=pg_options,
- timeout=timeout,
- pg_tag=pg_tag
- )
+ pg, pg_store = _new_process_group_helper(
+ group_world_size,
+ group_rank,
+ ranks,
+ backend,
+ default_store,
+ group_name=group_name,
+ pg_options=pg_options,
+ timeout=timeout,
+ pg_tag=pg_tag
+ )
# Create the global rank to group rank mapping
_world.pg_group_ranks[pg] = {