Initial commit of flight recorder trace (#130764)
Summary:
`fr_trace.py` is used to analyze flight recorder dump files.
This script was taken from @wconstab and @zdevito.
The only changes made were minor: making the linter happy and adding support for a few new fields that I added in version `2.2` of the collector portion.
Test Plan:
Tested manually on some flight recorder data and it seems to run.
TODO:
Address the 15-odd `# type: ignore` comments that I put in there to make the linter happy for now.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130764
Approved by: https://github.com/fduwjj
diff --git a/tools/flight_recorder/fr_trace.py b/tools/flight_recorder/fr_trace.py
new file mode 100644
index 0000000..dc05ea6
--- /dev/null
+++ b/tools/flight_recorder/fr_trace.py
@@ -0,0 +1,828 @@
+#!/usr/bin/env python3
+"""Flight Recorder Trace Analyzer
+
+This script primarily merges data from the individual flight recorder buffers of each rank in a
+PyTorch Distributed program into a flattened database format that can be used for further analysis.
+
+However, as part of the merging process, it is necessary to perform some analysis in order to match operators
+on one rank with corresponding operators on other ranks and register them as one 'collective' entry. During this
+process, a significant amount of useful information can already be extracted, such as where the first mismatch occurs
+in cases of desync (when not all ranks issue a compatible collective in a particular process group).
+
+
+Not Yet Implemented
+- TODO: tracebacks aren't implemented
+
+Known Issues
+- Flight Recorder buffer sequence_id information is not sufficient to match collectives and coalesced collectives
+ unless we have the trace data from the beginning of the program. To enable confident analysis of trace buffers that
+ do not start from zero (and to simplify the script's matching logic) we need to add more information to the recorder.
+- Currently, the script does not check the 'status' of collectives. We could easily find the first 'non-completed'
+  collective and report it.
+
+Usage
+python fr_trace.py -d <dump dir containing trace files> [-o <output file>]
+
+- Omitting the optional output file will still print the analysis information to stdout.
+- The output file is a pickle of the flat DB, whose format may change in the future.
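+
+Example (hypothetical paths):
+  python fr_trace.py -d /tmp/fr_dumps -p rank_ -o flat_db.pkl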
+"""
+
+import argparse
+import os
+import pickle
+import sys
+from typing import ( # type: ignore[attr-defined]
+ _eval_type,
+ Any,
+ Dict,
+ Generic,
+ List,
+ NamedTuple,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+)
+
+from tabulate import tabulate  # type: ignore[import-untyped]
+
+
+T = TypeVar("T", bound=NamedTuple)
+
+
+class Ref(Generic[T]):
+ pass
+
+
+class TypeInfo(NamedTuple):
+ name: str
+ fields: List[Tuple[str, Type]] # type: ignore[type-arg]
+
+ @classmethod
+ def from_type(cls, c: T) -> "TypeInfo":
+ if hasattr(c, "__name__"):
+ name = c.__name__
+ else:
+ name = str(c)
+ return cls(
+ name,
+ [(f, _eval_type(c.__annotations__[f], globals(), {})) for f in c._fields],
+ )
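+
+# For example (a sketch; `Group` is defined below), TypeInfo.from_type(Group)
+# evaluates to:
+#   TypeInfo(name="Group", fields=[("id", int), ("desc", str), ("size", int)])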
+
+
+"""
+Schema for flat DB
+
+TODO schemas not yet implemented
+# threads as recorded at termination of process
+Threads
+ id: int
+ traceback_id: int
+ process_id: int
+
+Process:
+ id: int # Same as world groups RANK
+ pid: int
+ hostname: str
+
+NCCLOp:
+ # nccl op implementation details (sends/recv)
+ id: int
+ nccl_call_id: int
+
+"""
+
+
+class Group(NamedTuple):
+ id: int
+ desc: str
+ size: int
+
+
+class Membership(NamedTuple):
+ group_id: Ref[Group]
+ global_rank: int
+
+
+class Traceback(NamedTuple):
+ id: int
+ frames: str
+
+
+class Collective(NamedTuple):
+ id: int
+ group_id: Ref[Group]
+
+
+class NCCLCall(NamedTuple):
+ id: int
+ collective_id: Ref[Collective]
+ group_id: Ref[Group]
+ global_rank: int # technically Ref[Process] once we have it
+ traceback_id: Ref[Traceback]
+ collective_type: str
+ sizes: List[List[int]]
+
+
+class Database(NamedTuple):
+ groups: List[Group]
+ memberships: List[Membership]
+ tracebacks: List[Traceback]
+ collectives: List[Collective]
+ ncclcalls: List[NCCLCall]
+
+
+types = [
+ TypeInfo.from_type(t) # type: ignore[type-var]
+ for t in globals().values()
+ if (
+ isinstance(t, type)
+ and issubclass(t, tuple)
+ and hasattr(t, "_fields")
+ and t is not TypeInfo
+ )
+]
+
+"""
+Stacktrace cache
+TODO
+"""
+
+
+"""
+Collective Matching logic
+"""
+COLLECTIVES = {
+ "broadcast",
+ "all_gather",
+ "all_reduce",
+ "_all_gather_base",
+ "all_gather_into_tensor_coalesced",
+ "reduce_scatter_tensor_coalesced",
+ "_reduce_scatter_base",
+ "gather",
+ "scatter",
+}
+
+P2P = {
+ "send",
+ "recv",
+}
+
+
+class Op:
+ """Parses relevant info about operation out of 'event' dict
+
+ examples of supported `profiling_name`s:
+ nccl:broadcast
+ nccl:send 1->2
+ nccl:recv 3<-0
+ """
+
+ def __init__(self, event: Dict[Any, Any], memberships: Dict[str, List[Membership]]):
+ profiling_name = event["profiling_name"]
+ nccl, name = profiling_name.split(":")
+ assert nccl == "nccl", f"name formatting error? {nccl} != 'nccl'"
+ parts = name.split(" ")
+ type = parts[0]
+ meta = parts[1] if len(parts) == 2 else None
+ self.state = event["state"]
+
+ self.pg_name, _ = event["process_group"]
+
+ assert type in COLLECTIVES | P2P | {
+ "coalesced"
+ }, f"{type} is not a supported operation"
+ self.type = type
+ if type == "send":
+ s, d = meta.split("->")
+ self._src, self._dst = int(s), int(d)
+ elif type == "recv":
+ d, s = meta.split("<-")
+ self._dst, self._src = int(d), int(s)
+ else:
+ self._src, self._dst = -1, -1
+        self._init_global_src_dst(memberships[self.pg_name])
+
+ if type in P2P | COLLECTIVES:
+ self.input_sizes = event["input_sizes"]
+ self.output_sizes = event["output_sizes"]
+ else:
+ self.input_sizes, self.output_sizes = None, None
+ self.collective_seq_id = event["collective_seq_id"]
+ self.p2p_seq_id = event["p2p_seq_id"]
+
+    def _init_global_src_dst(self, pg_ranks: List[Membership]) -> None:
+        pg_ranks = sorted(pg_ranks)
+        # -1 is the sentinel set in __init__ for non-p2p ops; only p2p ops have
+        # meaningful global src/dst ranks
+        self._src_g = pg_ranks[self._src] if self._src != -1 else None
+        self._dst_g = pg_ranks[self._dst] if self._dst != -1 else None
+
+ @property
+ def src(self) -> int:
+ assert self.type in P2P, "can't get src of non-p2p op"
+ return self._src
+
+ @property
+ def dst(self) -> int:
+ assert self.type in P2P, "can't get dst of non-p2p op"
+ return self._dst
+
+ def __repr__(self) -> str:
+ if self.type in P2P:
+ return (
+ f"{self.type}(s={self._src_g} d={self._dst_g}, sz={self.input_sizes})"
+ )
+ return f"{self.type}(input_sizes={self.input_sizes}, {self.state})"
+
+    def match(self, other: "Op") -> bool:
+ # TODO: I think this can validly not match,
+ # e.g. if one PG was used for p2p ops between only some of the peers?
+ # if self.seq_id != other.seq_id:
+ # return False
+
+ if self.type == "send":
+ return bool(
+ other.type == "recv"
+ and self.src == other.src
+ and self.dst == other.dst
+ and self.input_sizes == other.output_sizes
+ )
+ elif self.type == "recv":
+ return bool(
+ other.type == "send"
+ and self.src == other.src
+ and self.dst == other.dst
+ and self.output_sizes == other.input_sizes
+ )
+ elif self.type in COLLECTIVES:
+ return bool(
+ self.type == other.type and self.input_sizes == other.input_sizes
+ )
+        # TODO(whc) - output sizes don't have to match for e.g. gather; not sure if they ever have to match?
+ # and self.output_sizes == other.output_sizes)
+ elif self.type == "coalesced":
+ return bool(other.type == "coalesced")
+ return False
+
+
+def match_one_event(
+ event_a: Dict[Any, Any],
+ event_b: Dict[Any, Any],
+ memberships: Dict[str, List[Membership]],
+) -> bool:
+ op_a = Op(event_a, memberships)
+ op_b = Op(event_b, memberships)
+ return op_a.match(op_b)
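+
+# A minimal sketch of p2p matching (hypothetical event dicts; real dump entries
+# carry additional fields such as record_id and timestamps):
+#   members = {"0": [Membership(group_id="0", global_rank=0),
+#                    Membership(group_id="0", global_rank=1)]}
+#   send_ev = {"profiling_name": "nccl:send 0->1", "state": "completed",
+#              "process_group": ("0", "default_pg"), "input_sizes": [[4]],
+#              "output_sizes": [[4]], "collective_seq_id": 0, "p2p_seq_id": 7}
+#   recv_ev = dict(send_ev, profiling_name="nccl:recv 1<-0")
+#   match_one_event(send_ev, recv_ev, members)  # -> True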
+
+
+def match_coalesced_groups(
+ all_rank_events: Dict[Any, Any],
+ group_size: int,
+ groups: Dict[str, Group],
+ memberships: Dict[str, List[Membership]],
+) -> bool:
+ """
+ all_rank_events: {
+ rank: [
+ (idx, event_dict)
+ ]
+ }
+
+ Note: it is possible for event dicts in a coalesced group to be asymmetric.
+ e.g. the following events lists form a valid coalescing group
+ events0 [send:1]
+ events1 [recv:0, send:2]
+ events2 [recv:1]
+
+ Rule 1: all ops should find a match
+ Rule 2: relative ordering of sends and recvs in one event list can be arbitrary
+ e.g.
+ events1 [recv:0, send:2] —> okay
+ events1 [send:2, recv:0] —> also okay
+ Rule 3: sends to the same dest or recvs from the src should be in a consistent order
+ e.g.
+ rank0 [send:1 (100B), send:1 (1000B)]
+ rank1 [recv:0 (1000B), recv:0 (100B)] —> not okay
+ """
+ all_ops = {
+ rank: [Op(e, memberships) for i, e in all_rank_events[rank]]
+ for rank in all_rank_events
+ }
+
+ def visualize_ops(match: bool) -> None:
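+        # rebuild the op lists from all_rank_events, since the outer all_ops is
+        # consumed (popped) as matches are found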
+ all_ops = {
+ rank: [Op(e, memberships) for i, e in all_rank_events[rank]]
+ for rank in all_rank_events
+ }
+
+ i = 0
+ row = []
+ progress = True
+ table = []
+ while progress:
+ progress = False
+ for r in all_ops:
+ if len(all_ops[r]) > i:
+ _, event = all_rank_events[r][i]
+ row.append(Op(event, memberships))
+ progress = True
+ else:
+ row.append(None) # type: ignore[arg-type]
+ table.append(row)
+ row = []
+ i += 1
+ title = "Match" if match else "MISMATCH"
+ print(f"{title}\n", tabulate(table)) # type: ignore[operator]
+
+ # TODO can't verify seq_id bc there might have been valid seq deltas between ranks even within a pg.
+ for op_list in all_ops.values():
+ if not op_list:
+ # print("TODO- not sure if its valid for only some ranks in a PG to participate in a coalesced op?")
+ return False
+ assert op_list[-1].type == "coalesced"
+ op_list.pop(-1)
+
+ while all_ops:
+ first_rank = next(iter(all_ops))
+ my_ops = all_ops[first_rank]
+
+ if len(all_ops[first_rank]) == 0:
+ all_ops.pop(first_rank)
+ continue
+
+        # let's match the first collective! we need to know which ranks are involved, and ensure that this same
+ # collective is also the first one on those ranks within that group
+ op = my_ops[0]
+ match_idx = -1
+ if op.type in P2P:
+ dst_global_rank = sorted(memberships[op.pg_name])[op.dst]
+ peer_ops = all_ops[dst_global_rank]
+ for i, other in enumerate(peer_ops):
+ if op.match(other):
+ match_idx = i
+ break
+ elif op.dst == other.src:
+ # Rule 3
+ break
+ else:
+ # Rule 1
+ continue
+ else:
+ raise NotImplementedError("coalesced collective ops")
+ if match_idx >= 0:
+ my_ops.pop(0)
+ peer_ops.pop(match_idx)
+ else:
+ visualize_ops(False)
+ return False
+
+ visualize_ops(True)
+ return True
+
+
+"""
+Flat DB builder
+"""
+
+
+def build_groups_memberships(
+ pg_config: Any,
+) -> Tuple[List[Group], Dict[Any, Group], List[Membership], Dict[str, Any]]:
+ """
+ pg_config: {
+ global_rank: {
+ (pg_id, desc, ranks)
+ }
+ }
+
+    `pg_id` is a system-generated id, but depending on the mode of PG creation it could be a globally incrementing int
+ or a hash of the ranks. See `_process_group_name` in distributed_c10d.py.
+ `desc` is provided by the user (optionally) and should be 'meaningful' (e.g. TP/PP/DP group)
+ `ranks` is a list of the 'global ranks' that are members of the PG.
+
+ (pg_id, desc, ranks) tuples are appended lazily to the flight buffer when `getNCCLComm` is called on a PG and
+ the `enabled_` flag is true for that PG.
+ - the order of calling (init_process_group, new_group, etc) does not affect the order of the tuples in the list
+
+ Returns: a groups table and a membership table, where each row is a Group or Membership namedtuple
+ """
+ # flat lists for return
+ groups = []
+ memberships = []
+
+ # dicts for faster cross-rank validation
+ _groups = {}
+ _memberships = {}
+ for global_rank in pg_config:
+ for pg_id in pg_config[global_rank]:
+ desc = pg_config[global_rank][pg_id]["desc"]
+ ranks = pg_config[global_rank][pg_id]["ranks"]
+ if isinstance(ranks, str):
+ # TODO Bug in FR data format? ranks is '[0, 1,...]'
+ ranks = eval(ranks)
+
+ if pg_id not in _groups:
+ groups.append(Group(id=pg_id, desc=desc, size=len(ranks)))
+ for rank in ranks:
+ memberships.append(Membership(group_id=pg_id, global_rank=rank))
+ _groups[pg_id] = groups[-1]
+ _memberships[pg_id] = set(ranks)
+ else:
+ # validation across ranks
+ assert (
+ _groups[pg_id].desc == desc
+ ), f"mismatch in desc {_groups[pg_id].desc} vs {desc}"
+ assert _memberships[pg_id] == set(
+ ranks
+ ), f"mismatch in membership {_memberships[pg_id]} vs {set(ranks)}"
+ return groups, _groups, memberships, _memberships
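+
+# Sketch of the expected input (hypothetical 2-rank world with a single PG "0"):
+#   pg_config = {
+#       0: {"0": {"desc": "default_pg", "ranks": "[0, 1]"}},
+#       1: {"0": {"desc": "default_pg", "ranks": "[0, 1]"}},
+#   }
+#   groups, _groups, memberships, _memberships = build_groups_memberships(pg_config)
+#   # groups == [Group(id="0", desc="default_pg", size=2)]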
+
+
+def build_nccl_call(
+ entry: Dict[Any, Any],
+ id: int,
+ collective_id: Any,
+ group_id: int,
+ global_rank: Any,
+) -> NCCLCall:
+ return NCCLCall(
+ id=id,
+ collective_id=collective_id,
+ group_id=group_id, # type: ignore[arg-type]
+ global_rank=global_rank,
+ traceback_id=0, # type: ignore[arg-type]
+ collective_type=entry["profiling_name"],
+ sizes=entry["input_sizes"],
+ )
+
+
+def find_coalesced_group(
+ pg_name: str, entries: List[Dict[str, Any]]
+) -> List[Tuple[int, Dict[str, Any]]]:
+ """Given a list of entries, if the collective_seq_id of the first entry matches that of subsequent ones,
+ build an return a list of entries terminating in a 'coalesced' op entry all sharing a collective_seq_id
+ TODO: handle p2p_seq_id v/s collective_seq_id separately here.
+ """
+ found = []
+ collective_seq_id = None
+ for i, e in enumerate(entries):
+ if e["process_group"][0] != pg_name:
+ continue
+ elif collective_seq_id is None:
+ collective_seq_id = e["collective_seq_id"]
+ found.append((i, e))
+ elif e["collective_seq_id"] == collective_seq_id:
+ found.append((i, e))
+ else:
+ break
+
+ if len(found) > 1:
+ assert found[-1][1]["profiling_name"] == "nccl:coalesced"
+ return found
+ return []
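+
+# Sketch (hypothetical entries, trimmed to the fields this function reads):
+#   entries = [
+#       {"process_group": ("0", ""), "collective_seq_id": 5, "profiling_name": "nccl:send 0->1"},
+#       {"process_group": ("0", ""), "collective_seq_id": 5, "profiling_name": "nccl:recv 0<-1"},
+#       {"process_group": ("0", ""), "collective_seq_id": 5, "profiling_name": "nccl:coalesced"},
+#   ]
+#   find_coalesced_group("0", entries)  # -> [(0, {...}), (1, {...}), (2, {...})]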
+
+
+def just_print_entries(
+ all_entries: Dict[int, List[Dict[str, Any]]],
+ _groups: Dict[str, Group],
+ _memberships: Dict[str, List[Membership]],
+) -> None:
+ rows = []
+ ranks = sorted(all_entries.keys())
+ headers = [f"Rank {rank}" for rank in ranks]
+ progress = True
+ while progress:
+ progress = False
+ row = []
+ for rank in ranks:
+ if len(all_entries[rank]) == 0:
+ row.append("")
+ else:
+ entry = all_entries[rank].pop(0)
+ row.append(str(Op(entry, _memberships)))
+ progress = True
+ if progress:
+ rows.append(row)
+
+    print(tabulate(rows, headers=headers))
+
+
+def build_collectives(
+ all_entries: Dict[int, List[Dict[str, Any]]],
+ _groups: Dict[str, Group],
+ _memberships: Dict[str, List[Membership]],
+) -> Tuple[List[Traceback], List[Collective], List[NCCLCall]]:
+ """
+ groups, memberships are the non-flat dicts that are indexable
+ all_entries is a raw dict from the original dumps:
+
+ all_entries: {
+ global_rank: [
+ {
+ record_id: ordered id of the event in the trace buffer
+ pg_id: ProcessGroupNCCL::uid_
+ *note: `pg_id` corresponds to nothing in groups table
+ process_group: (pg_name, desc)
+ *note: `pg_name`, `desc` corresponds to `pg_id`, `desc` in groups table
+ collective_seq_id: ordered id for collective operations and coalesced group operations
+ p2p_seq_id: ordered id for point-to-point operations
+ op_id: ordered id including individual ops inside coalescing group
+ profiling_name: descriptive name of the operation
+ 'time_created_ns',
+ 'input_sizes',
+ 'output_sizes',
+ 'state',
+ 'time_discovered_started_ns',
+ 'time_discovered_completed_ns',
+ 'retired',
+ 'frames',
+ }
+ ]
+ }
+ """
+ tracebacks: List[Traceback] = []
+
+ collectives: List[Collective] = []
+ nccl_calls: List[NCCLCall] = []
+
+ # once we find one mismatch, we stop pairing up collectives since the pairing is possibly incorrect
+ # instead, just record the remaining ops as NCCLCalls
+ mismatch = {_groups[g].id: 0 for g in _groups}
+ MISMATCH_TAIL = 10
+ """
+ - it doesn't matter what order I put collectives/ncclops into their table. we can later on re-sort it by start time
+ - there could be multiple options for the "first" collective to pair up (rank 0,1 might do a bcast while rank 2,3 do a bcast)
+ - within a group, the first collective must be the same on all ranks in the group, then it can be marked as a
+ collective and removed
+ """
+ while all_entries:
+ # we greedily match collectives, starting arbitrarily with the trace from the first rank
+ # later, if we exhaust the first rank, we continue with the next 'first rank'
+ rank_iter = iter(all_entries)
+ first_rank = next(rank_iter)
+ other_ranks = list(rank_iter)
+
+ if len(all_entries[first_rank]) == 0:
+ all_entries.pop(first_rank)
+ continue
+
+        # let's match the first collective! we need to know which ranks are involved, and ensure that this same
+ # collective is also the first one on those ranks within that group
+ entries = all_entries[first_rank]
+ pg_name, desc = entries[0]["process_group"]
+ profiling_name = entries[0]["profiling_name"]
+ expected_ranks = set(_memberships[pg_name])
+ found_ranks = {first_rank}
+ found_idx = {}
+
+ if find_coalesced_group(pg_name, entries):
+ expected_ranks.add(first_rank)
+ done_ranks = set()
+ all_coalesced_entries = {}
+ while expected_ranks:
+ curr = expected_ranks.pop()
+ done_ranks.add(curr)
+ grp = (
+ find_coalesced_group(pg_name, all_entries[curr]) # type: ignore[index]
+ if curr in all_entries # type: ignore[comparison-overlap]
+ else []
+ )
+ all_coalesced_entries[curr] = grp
+ for index, entry in grp:
+ op = Op(entry, _memberships)
+ peer = None
+ if op.type == "send":
+ assert op._src_g == curr, (op._src_g, curr)
+ peer = op._dst_g
+ elif op.type == "recv":
+ assert op._dst_g == curr, (op._dst_g, curr)
+ peer = op._src_g
+ if peer and peer not in done_ranks:
+ expected_ranks.add(peer)
+
+ match = match_coalesced_groups(
+ all_coalesced_entries,
+ group_size=_groups[pg_name].size,
+ groups=_groups,
+ memberships=_memberships,
+ )
+
+ if match and mismatch[pg_name] == 0:
+ collectives.append(Collective(id=len(collectives), group_id=pg_name))
+ else:
+ mismatch[pg_name] += 1
+
+ for r in all_coalesced_entries:
+ reversed_calls = []
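+                # pop the highest indices first so the remaining indices stay valid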
+ for i, _ in reversed(all_coalesced_entries[r]):
+ reversed_calls.append(
+ build_nccl_call(
+ all_entries[r].pop(i), # type: ignore[index]
+ id=len(nccl_calls),
+ collective_id=collectives[-1].id if match else None,
+ group_id=pg_name,
+ global_rank=r,
+ )
+ )
+ nccl_calls.extend(reversed(reversed_calls))
+
+ else:
+ for o in expected_ranks.intersection(set(other_ranks)):
+ for i, e in enumerate(all_entries[o]): # type: ignore[index]
+ # step over ops from other PGs
+ if e["process_group"] == (pg_name, desc):
+ if (
+ match_one_event(entries[0], e, _memberships)
+ and mismatch[pg_name] == 0
+ ):
+ found_ranks.add(o)
+ found_idx[o] = i
+ else:
+ # we found a mismatch. what do we do with that?
+ mismatch[pg_name] += 1
+ print(
+ f"Mismatched collective on rank {o} for group {pg_name}:{desc} collective {profiling_name}"
+ )
+ break
+
+ # at this point there are 3 possibilities
+ # 1. we found a match on all the ranks that are members of the group
+ # -> we create a Collective and remove the individual entries from their original lists
+ if found_ranks == expected_ranks and mismatch[pg_name] == 0:
+ collectives.append(Collective(id=len(collectives), group_id=pg_name))
+ for r in found_ranks:
+ i = found_idx[r] if r != first_rank else 0
+ nccl_calls.append(
+ build_nccl_call(
+ all_entries[r].pop(i), # type: ignore[index]
+ id=len(nccl_calls),
+ collective_id=collectives[-1].id,
+ group_id=pg_name,
+ global_rank=r,
+ )
+ )
+
+ # 2. we found a partial match but some ranks are missing
+ # 3. we found no match
+            # -> since it's not a complete collective, no entry goes into collectives, but we still record a nccl call
+ # TODO should there be a way to mark 'mismatches'?
+ else:
+ print("appending a non-matching collective")
+ nccl_calls.append(
+ build_nccl_call(
+ all_entries[first_rank].pop(0),
+ id=len(nccl_calls),
+ collective_id=None,
+ group_id=pg_name,
+                    global_rank=first_rank,
+ )
+ )
+
+ if mismatch[pg_name] > MISMATCH_TAIL:
+ print(f"Too many mismatches for process_group {pg_name}:{desc}, aborting")
+ sys.exit(-1)
+
+ return tracebacks, collectives, nccl_calls
+
+
+def check_no_missing_dump_files(
+ entries: Dict[str, Any], memberships: List[Membership]
+) -> None:
+ all_ranks = set()
+ for membership in memberships:
+ all_ranks.add(str(membership.global_rank))
+ dumps_ranks = set(entries.keys())
+ assert (
+ dumps_ranks == all_ranks
+ ), f"Missing dump files from ranks {all_ranks - dumps_ranks}"
+
+
+def check_version(versions: Dict[str, Any]) -> None:
+ for rank, version in versions.items(): # noqa: PERF102
+ major, minor = map(int, version.split("."))
+ # assert major == 2, f"Rank {rank} unsupported version {version}"
+ # assert minor >= 0, f"Rank {rank} unsupported version {version}"
+
+
+def check_trace_from_beginning(entries: Dict[str, Any]) -> bool:
+ for rank in entries:
+ first_record_id = entries[rank][0]["record_id"]
+ # TODO add more sequence information such that analysis can proceed even without complete buffer
+
+    # assert first_record_id == 0, f"Rank {rank} trace does not start at time 0 (first record is {first_record_id})."
+ if first_record_id != 0:
+ print(
+ f"Rank {rank} trace does not start at time 0 (first record is {first_record_id}."
+ )
+ return False
+ return True
+
+
+def build_db(details: Dict[str, Dict[str, Any]], args: argparse.Namespace) -> Database:
+ # temporary state used for building database
+ entries = {}
+ pg_config = {}
+ version = {}
+ for dump in details.values():
+ rank = dump["rank"]
+ entries[rank] = dump["entries"]
+ version[rank] = dump["version"]
+ pg_config[rank] = dump["pg_config"]
+
+ check_version(version)
+ check_trace_from_beginning(entries)
+
+ # flattened database
+ groups, _groups, memberships, _memberships = build_groups_memberships(pg_config)
+ print("built groups, memberships")
+
+ check_no_missing_dump_files(entries, memberships)
+
+ if args.just_print_entries:
+ just_print_entries(entries, _groups, _memberships)
+ sys.exit(0)
+
+ tracebacks, collectives, nccl_calls = build_collectives(
+ entries, _groups, _memberships
+ )
+ print("built collectives, nccl_calls")
+ if args.verbose:
+ print("Groups\n", tabulate(groups, headers=Group._fields)) # type: ignore[operator]
+ print("Memberships\n", tabulate(memberships, headers=Membership._fields)) # type: ignore[operator]
+ print("Collectives\n", tabulate(collectives, headers=Collective._fields)) # type: ignore[operator]
+ print("NCCLCalls\n", tabulate(nccl_calls, headers=NCCLCall._fields)) # type: ignore[operator]
+ db = Database(
+ tracebacks=tracebacks,
+ collectives=collectives,
+ ncclcalls=nccl_calls,
+ groups=groups,
+ memberships=memberships,
+ )
+ return db
+
+
+def read_dump(prefix: str, filename: str) -> Dict[str, Union[str, int, List[Any]]]:
+ basename = os.path.basename(filename)
+ assert (
+ basename.find(prefix) == 0
+ ), f"args.prefix ({prefix}) must match the beginning of each filename ({basename})"
+ rank = int(basename[len(prefix) :])
+ host_name = f"host_rank{rank}"
+
+ with open(filename, "rb") as infile:
+ dump = pickle.load(infile)
+
+ entries = dump["entries"]
+ version = dump["version"]
+ pg_config = dump["pg_config"]
+
+ return {
+ "host_name": host_name,
+ "rank": rank,
+ "entries": entries,
+ "version": version,
+ "pg_config": pg_config,
+ }
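+
+# e.g. read_dump("rank_", "/tmp/fr_dumps/rank_3") (hypothetical path) would return
+# {"host_name": "host_rank3", "rank": 3, ...} with the entries/version/pg_config
+# payloads unpickled from that file.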
+
+
+def read_dir(prefix: str, folder: str) -> Dict[Any, Any]:  # TODO: fix types
+ import gc
+ import time
+
+ gc.disable()
+ details = {}
+ t0 = time.time()
+ for root, _, files in os.walk(folder):
+ for f in files:
+ ta = time.time()
+ details[f] = read_dump(prefix, os.path.join(root, f))
+ tb = time.time()
+ # print(f"read file {f} in {tb - ta}s")
+ print(f"loaded {len(files)} files in {tb - t0}s")
+ return details
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument("-d", "--dir", help="Directory with flight recorder dumps")
+ parser.add_argument("-o", "--output", default=None)
+ parser.add_argument(
+ "-p",
+ "--prefix",
+ help="prefix to strip such that rank can be extracted",
+ default="rank_",
+ )
+ parser.add_argument("-j", "--just_print_entries", action="store_true")
+ parser.add_argument("-v", "--verbose", action="store_true")
+ args = parser.parse_args()
+
+ details = read_dir(args.prefix, args.dir)
+ db = build_db(details, args)
+ if args.output:
+ with open(args.output, "wb") as f:
+ pickle.dump((types, db), f)
+
+
+if __name__ == "__main__":
+ main()