[PyTorch][ET] Collect Process Groups Mapping Info (#104373)
Summary: Add the logics and interface to log ProcessGroup comms configuration (unique ID, type, and ranks info).
Test Plan:
Testing in HPC:
```
TORCH_LOGS=all ../buck-out/v2/gen/fbcode/c8344b52091f4f7f/hpc/models/ads/__ads_10x_launcher__/ads_10x_launcher.par  +launcher=local launcher.num_trainers=4 +data_loader=random data_loader.num_batches=2000
```
Example output in ET:
```
    {
      "name": "## process_group:init ##", "id": 3, "rf_id": 1, "parent": 2, "fw_parent": 0, "seq_id": -1, "scope": 7, "tid": 1, "fw_tid": 0, "op_schema": "",
      "inputs": ["[{'pg_id': 140538064364672, 'backend_id': 140538060772480, 'backend_config': 'cuda:nccl', 'ranks': {0: 0, 1: 1, 2: 2, 3: 3}}, {'pg_id': 140538064363904, 'backend_id': 140538042628864, 'backend_config': 'cuda:nccl', 'ranks': {0: 0, 1: 1, 2: 2, 3: 3}}]"], "input_shapes": [[]], "input_types": ["String"],
      "outputs": [], "output_shapes": [], "output_types": []
    },
```
Differential Revision: D46321690
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104373
Approved by: https://github.com/kwen2501
diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi
index b3f9730..4c06486 100644
--- a/torch/_C/_autograd.pyi
+++ b/torch/_C/_autograd.pyi
@@ -91,7 +91,7 @@
 def _add_metadata_json(key: str, value: str) -> None: ...
 def _kineto_step() -> None: ...
 def kineto_available() -> bool: ...
-def _record_function_with_args_enter(name: str, args: List[Any]) -> torch.Tensor: ...
+def _record_function_with_args_enter(name: str, *args) -> torch.Tensor: ...
 def _record_function_with_args_exit(handle: torch.Tensor) -> None: ...
 def _supported_activities() -> Set[ProfilerActivity]: ...
 def _enable_record_function(enable: bool) -> None: ...
diff --git a/torch/csrc/distributed/c10d/Backend.hpp b/torch/csrc/distributed/c10d/Backend.hpp
index 9d43bd5..ea7d5e8 100644
--- a/torch/csrc/distributed/c10d/Backend.hpp
+++ b/torch/csrc/distributed/c10d/Backend.hpp
@@ -53,6 +53,12 @@
     return size_;
   }
 
+  // Returns an unique opaque ID of this backend that can be used to correlate
+  // with its collectives.
+  int64_t getID() const {
+    return reinterpret_cast<std::intptr_t>(this);
+  }
+
   virtual void startCoalescing() {
     TORCH_CHECK(
         false,
diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp
index 63d2a1a..8567dc1 100644
--- a/torch/csrc/distributed/c10d/ProcessGroup.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp
@@ -93,6 +93,17 @@
     return size_;
   }
 
+  // Returns an unique opaque ID of this process group object.
+  int64_t getID() const {
+    return reinterpret_cast<std::intptr_t>(this);
+  }
+
+  // Returns an unique opaque ID of a backend for the specific backend type
+  // that can correlate with this process group's collectives.
+  int64_t getBackendID(BackendType backend_type) const {
+    return reinterpret_cast<std::intptr_t>(getBackend(backend_type).get());
+  }
+
   virtual const std::string getBackendName() const {
     return options_->backend;
   };
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index 2820b3c..19789c4 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -670,13 +670,14 @@
             << "\nTIMEOUT(ms): " << options_->timeout.count()
             << "\nUSE_HIGH_PRIORITY_STREAM: "
             << options_->is_high_priority_stream
-            << "\n TORCH_DISTRIBUTED_DEBUG: "
+            << "\nTORCH_DISTRIBUTED_DEBUG: "
             << std::string(torch_distributed_debug)
-            << "\n NCCL_DEBUG: " << std::string(nccl_debug);
+            << "\nNCCL_DEBUG: " << std::string(nccl_debug)
+            << "\nID=" << this->getID();
 
   RECORD_PARAM_COMMS(
       0, // seq
-      reinterpret_cast<std::intptr_t>(this), // process group ptr
+      this->getID(),
       rank, // rank
       "init", // colName
       0, // inSize
@@ -1963,7 +1964,7 @@
   RECORD_PARAM_COMMS_DATA(
       static_cast<int>(
           this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
-      reinterpret_cast<std::intptr_t>(this), // process group ptr
+      this->getID(),
       tensors, // inputTensors
       tensors, // outputTensors
       rank_, // rank
@@ -1987,7 +1988,7 @@
   RECORD_PARAM_COMMS_DATA(
       static_cast<int>(
           this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
-      reinterpret_cast<std::intptr_t>(this), // process group ptr
+      this->getID(),
       tensors, // inputTensors
       tensors, // outputTensors
       rank_, // rank
@@ -2014,7 +2015,7 @@
   RECORD_PARAM_COMMS_DATA(
       static_cast<int>(
           this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
-      reinterpret_cast<std::intptr_t>(this), // process group ptr
+      this->getID(),
       tensors, // inputTensors
       tensors, // outputTensors
       rank_, // rank
@@ -2073,7 +2074,7 @@
       static_cast<int>(
           this->getSequenceNumberForGroup() +
           1), // seq + 1 to match collective increment.
-      reinterpret_cast<std::intptr_t>(this), // process group ptr
+      this->getID(),
       inputTensors, // inputTensors
       outputTensors, // outputTensors
       rank_, // rank
@@ -2114,7 +2115,7 @@
   RECORD_PARAM_COMMS_DATA(
       static_cast<int>(
           this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
-      reinterpret_cast<std::intptr_t>(this),
+      this->getID(),
       tensors, // inputTensors
       tensors, // outputTensors
       rank_, // rank
@@ -2177,7 +2178,7 @@
   RECORD_PARAM_COMMS_DATA(
       static_cast<int>(
           this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
-      reinterpret_cast<std::intptr_t>(this), // process group ptr
+      this->getID(),
       inputTensors, // inputTensors
       outputTensors, // outputTensors
       rank_, // rank
@@ -2233,7 +2234,7 @@
         static_cast<int>(
             this->getSequenceNumberForGroup() +
             1), // seq + 1 to match collective
-        reinterpret_cast<std::intptr_t>(this), // process group ptr
+        this->getID(),
         inputTensors, // inputTensors
         outputTensors, // outputTensors
         rank_, // rank
@@ -2374,7 +2375,7 @@
         static_cast<int>(
             this->getSequenceNumberForGroup() +
             1), // seq + 1 to match collective
-        reinterpret_cast<std::intptr_t>(this), // process group ptr
+        this->getID(),
         inputTensors, // inputTensors
         outputTensors, // outputTensors
         rank_, // rank
@@ -2493,7 +2494,7 @@
   RECORD_PARAM_COMMS_DATA(
       static_cast<int>(
           this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
-      reinterpret_cast<std::intptr_t>(this), // process group ptr
+      this->getID(),
       inputTensor, // inputTensor
       outputTensor, // outputTensor
       rank_, // rank
@@ -2572,7 +2573,7 @@
   RECORD_PARAM_COMMS(
       static_cast<int>(
           this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
-      reinterpret_cast<std::intptr_t>(this), // process group ptr
+      this->getID(),
       rank_, // rank
       "barrier", // colName
       0, // inSize
@@ -2650,7 +2651,7 @@
         static_cast<int>(
             this->getSequenceNumberForGroup() +
             1), // seq + 1 to match collective
-        reinterpret_cast<std::intptr_t>(this), // process group ptr
+        this->getID(),
         inputTensor, // inputTensor
         outputTensor, // outputTensor
         rank_, // rank
@@ -2691,7 +2692,7 @@
         static_cast<int>(
             this->getSequenceNumberForGroup() +
             1), // seq + 1 to match collective
-        reinterpret_cast<std::intptr_t>(this), // process group ptr
+        this->getID(),
         inputTensor, // inputTensor
         outputTensor, // outputTensor
         rank_, // rank
@@ -2955,7 +2956,7 @@
   RECORD_PARAM_COMMS_DATA(
       static_cast<int>(
           this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
-      reinterpret_cast<std::intptr_t>(this), // process group ptr
+      this->getID(),
       inputTensors, // inputTensors
       outputTensors, // outputTensors
       rank_, // rank
@@ -3041,7 +3042,7 @@
   RECORD_PARAM_COMMS_DATA(
       static_cast<int>(
           this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
-      reinterpret_cast<std::intptr_t>(this), // process group ptr
+      this->getID(),
       inputTensors, // inputTensors
       outputTensors, // outputTensors
       rank_, // rank
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 7cde2e0..d9e4064 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -1375,6 +1375,11 @@
           .def("rank", &::c10d::ProcessGroup::getRank)
           .def("size", &::c10d::ProcessGroup::getSize)
           .def("name", &::c10d::ProcessGroup::getBackendName)
+          .def("_id", &::c10d::ProcessGroup::getID)
+          .def(
+              "_backend_id",
+              &::c10d::ProcessGroup::getBackendID,
+              py::arg("backend_type"))
           .def_property_readonly("options", &::c10d::ProcessGroup::getOptions)
           .def(
               "broadcast",
diff --git a/torch/csrc/profiler/standalone/execution_trace_observer.cpp b/torch/csrc/profiler/standalone/execution_trace_observer.cpp
index d2bc513..93d9e2d 100644
--- a/torch/csrc/profiler/standalone/execution_trace_observer.cpp
+++ b/torch/csrc/profiler/standalone/execution_trace_observer.cpp
@@ -44,6 +44,8 @@
   return fmt::format("[{}]", fmt::join(v, ","));
 }
 
+std::string json_str_escape(const std::string& str);
+
 constexpr size_t maxNumElements = 4096;
 
 inline std::string getValueType(
@@ -128,10 +130,11 @@
     if (str_val.size() > maxNumElements) {
       LOG(WARNING) << "string size=" << str_val.size()
                    << " exceeded maxNumElements=" << maxNumElements;
-      return fmt::format("\"{}\"", str_val.substr(0, maxNumElements));
+      return fmt::format(
+          "\"{}\"", json_str_escape(str_val.substr(0, maxNumElements)));
     }
 
-    return fmt::format("\"{}\"", str_val);
+    return fmt::format("\"{}\"", json_str_escape(str_val));
   } else if (val.isDevice()) {
     return fmt::format("\"{}\"", val.toDevice().str());
   }
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 0f2cbc7..bc5d13d 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -32,7 +32,6 @@
     get_debug_level,
     Work
 )
-from torch.autograd.profiler import record_function
 from .constants import default_pg_timeout
 from .c10d_logger import _exception_logger, _time_logger
 from .rendezvous import register_rendezvous_handler, rendezvous  # noqa: F401
@@ -181,6 +180,13 @@
         MPI : ["cpu"],
     }
 
+    backend_type_map: Dict[str, ProcessGroup.BackendType] = {
+        UNDEFINED: ProcessGroup.BackendType.UNDEFINED,
+        GLOO : ProcessGroup.BackendType.GLOO,
+        NCCL: ProcessGroup.BackendType.NCCL,
+        UCC: ProcessGroup.BackendType.UCC,
+    }
+
     def __new__(cls, name: str):
         if not isinstance(name, str):
             raise ValueError(f"Backend name must be a string, but got: {name}")
@@ -228,6 +234,7 @@
 
         setattr(Backend, name.upper(), name.lower())
         Backend.backend_list.append(name.lower())
+        Backend.backend_type_map[name.lower()] = ProcessGroup.BackendType.CUSTOM
 
         # Update device capability matrix in Backend class
         if devices is None:
@@ -506,6 +513,27 @@
     def pg_default_device(self) -> Dict[ProcessGroup, torch.device]:
         return self._pg_default_device
 
+    @property
+    def pg_config_info(self) -> List[Dict[str, Union[int, str]]]:
+        """
+        Returns a list of dict with process groups and backends with their unique IDs
+        and configurations (types and ranks).
+        """
+        config_info = []
+        for pg, backend in self.pg_map.items():
+            # backend is a tuple with the first element being the backend type ("nccl", etc.)
+            backend_type = Backend.backend_type_map[backend[0]]
+            config_info.append(
+                {
+                    "pg_id": pg._id(),
+                    "backend_id": pg._backend_id(backend_type),
+                    "backend_config": self.pg_backend_config[pg],
+                    "ranks": self.pg_group_ranks[pg],
+                }
+            )
+        return config_info
+
+
 _world = _World()
 """Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it"""
 
@@ -3888,18 +3916,17 @@
 
     group_name = _process_group_name(ranks, use_hashed_name=use_local_synchronization)
 
-    with record_function(f"## process_group:init with ranks: {ranks}"):
-        pg, pg_store = _new_process_group_helper(
-            group_world_size,
-            group_rank,
-            ranks,
-            backend,
-            default_store,
-            group_name=group_name,
-            pg_options=pg_options,
-            timeout=timeout,
-            pg_tag=pg_tag
-        )
+    pg, pg_store = _new_process_group_helper(
+        group_world_size,
+        group_rank,
+        ranks,
+        backend,
+        default_store,
+        group_name=group_name,
+        pg_options=pg_options,
+        timeout=timeout,
+        pg_tag=pg_tag
+    )
 
     # Create the global rank to group rank mapping
     _world.pg_group_ranks[pg] = {