Initial commit of flight recorder trace (#130764)

Summary:
`fr_trace.py` is used to analyze flight recorder dump files.
This script was taken from @wconstab and @zdevito.
The only changes were minor ones to satisfy the linter and to handle a few new fields that I added in version `2.2` of the collector.

Test Plan:
Tested manually on some flight recorder data and it seems to run.
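For example, an invocation along the lines of `python tools/flight_recorder/fr_trace.py -d <dump dir> -o flat_db.pkl` (the output filename here is illustrative).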

TODO:
Address the roughly 15 `# type: ignore` comments that I put in there to satisfy the linter for now.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130764
Approved by: https://github.com/fduwjj
diff --git a/tools/flight_recorder/fr_trace.py b/tools/flight_recorder/fr_trace.py
new file mode 100644
index 0000000..dc05ea6
--- /dev/null
+++ b/tools/flight_recorder/fr_trace.py
@@ -0,0 +1,828 @@
+#!/usr/bin/env python3
+"""Flight Recorder Trace Analyzer
+
+This script primarily merges the data from each rank's individual flight recorder buffer in a PyTorch Distributed
+program into a flattened database format that can be used for further analysis.
+
+However, as part of the merging process, it is necessary to perform some analysis in order to match operators
+on one rank with the corresponding operators on other ranks and register them as one 'collective' entry.  During this
+process, a significant amount of useful information can already be extracted, such as where the first mismatch occurs
+in cases of desync (when not all ranks issue a compatible collective in a particular process group).
+
+
+Not Yet Implemented
+- TODO: tracebacks aren't implemented
+
+Known Issues
+- Flight Recorder buffer sequence_id information is not sufficient to match collectives and coalesced collectives
+  unless we have the trace data from the beginning of the program.  To enable confident analysis of trace buffers that
+  do not start from zero (and to simplify the script's matching logic), we need to add more information to the recorder.
+- Currently, the script omits checking the 'status' of collectives.  We can look for the first 'non-completed'
+  collective easily enough and report that.
+
+Usage
+python fr_trace.py -d <dump dir containing trace files> [-o <output file>]
+
+- Omitting the optional output file will still yield analysis information to stdout.
+- The output file is a pickle of the flat DB, whose format may change in the future (see the example below).
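+
+Example of loading the pickled output for further analysis (a sketch; the filename and the exact
+pickle layout are illustrative and may change):
+    import pickle
+    with open("flat_db.pkl", "rb") as f:
+        types, db = pickle.load(f)  # `db` is the `Database` namedtuple defined below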
+"""
+
+import argparse
+import ast
+import os
+import pickle
+import sys
+from typing import (  # type: ignore[attr-defined]
+    _eval_type,
+    Any,
+    Dict,
+    Generic,
+    List,
+    NamedTuple,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)
+
+from tabulate import tabulate  # type: ignore[import-untyped]
+
+
+T = TypeVar("T", bound=NamedTuple)
+
+
+class Ref(Generic[T]):
+    pass
+
+
+class TypeInfo(NamedTuple):
+    name: str
+    fields: List[Tuple[str, Type]]  # type: ignore[type-arg]
+
+    @classmethod
+    def from_type(cls, c: T) -> "TypeInfo":
+        if hasattr(c, "__name__"):
+            name = c.__name__
+        else:
+            name = str(c)
+        return cls(
+            name,
+            [(f, _eval_type(c.__annotations__[f], globals(), {})) for f in c._fields],
+        )
+
+
+"""
+Schema for flat DB
+
+TODO schemas not yet implemented
+# threads as recorded at termination of process
+Threads
+    id: int
+    traceback_id: int
+    process_id: int
+
+Process:
+    id: int # Same as the world group's RANK
+    pid: int
+    hostname: str
+
+NCCLOp:
+    # nccl op implementation details (sends/recv)
+    id: int
+    nccl_call_id: int
+
+"""
+
+
+class Group(NamedTuple):
+    id: int
+    desc: str
+    size: int
+
+
+class Membership(NamedTuple):
+    group_id: Ref[Group]
+    global_rank: int
+
+
+class Traceback(NamedTuple):
+    id: int
+    frames: str
+
+
+class Collective(NamedTuple):
+    id: int
+    group_id: Ref[Group]
+
+
+class NCCLCall(NamedTuple):
+    id: int
+    collective_id: Ref[Collective]
+    group_id: Ref[Group]
+    global_rank: int  # technically Ref[Process] once we have it
+    traceback_id: Ref[Traceback]
+    collective_type: str
+    sizes: List[List[int]]
+
+
+class Database(NamedTuple):
+    groups: List[Group]
+    memberships: List[Membership]
+    tracebacks: List[Traceback]
+    collectives: List[Collective]
+    ncclcalls: List[NCCLCall]
+
+
+types = [
+    TypeInfo.from_type(t)  # type: ignore[type-var]
+    for t in globals().values()
+    if (
+        isinstance(t, type)
+        and issubclass(t, tuple)
+        and hasattr(t, "_fields")
+        and t is not TypeInfo
+    )
+]
+
+"""
+Stacktrace cache
+TODO
+"""
+
+
+"""
+Collective Matching logic
+"""
+COLLECTIVES = {
+    "broadcast",
+    "all_gather",
+    "all_reduce",
+    "_all_gather_base",
+    "all_gather_into_tensor_coalesced",
+    "reduce_scatter_tensor_coalesced",
+    "_reduce_scatter_base",
+    "gather",
+    "scatter",
+}
+
+P2P = {
+    "send",
+    "recv",
+}
+
+
+class Op:
+    """Parses relevant info about operation out of 'event' dict
+
+    examples of supported `profiling_name`s:
+        nccl:broadcast
+        nccl:send 1->2
+        nccl:recv 3<-0
+    """
+
+    def __init__(self, event: Dict[Any, Any], memberships: Dict[str, List[Membership]]):
+        profiling_name = event["profiling_name"]
+        nccl, name = profiling_name.split(":")
+        assert nccl == "nccl", f"name formatting error? {nccl} != 'nccl'"
+        parts = name.split(" ")
+        type = parts[0]
+        meta = parts[1] if len(parts) == 2 else None
+        self.state = event["state"]
+
+        self.pg_name, _ = event["process_group"]
+
+        assert type in COLLECTIVES | P2P | {
+            "coalesced"
+        }, f"{type} is not a supported operation"
+        self.type = type
+        if type == "send":
+            s, d = meta.split("->")
+            self._src, self._dst = int(s), int(d)
+        elif type == "recv":
+            d, s = meta.split("<-")
+            self._dst, self._src = int(d), int(s)
+        else:
+            self._src, self._dst = -1, -1
+        pg_name, pg_desc = event["process_group"]
+        self._init_global_src_dst(memberships[pg_name])
+
+        if type in P2P | COLLECTIVES:
+            self.input_sizes = event["input_sizes"]
+            self.output_sizes = event["output_sizes"]
+        else:
+            self.input_sizes, self.output_sizes = None, None
+        self.collective_seq_id = event["collective_seq_id"]
+        self.p2p_seq_id = event["p2p_seq_id"]
+
+    def _init_global_src_dst(self, pg_ranks: List[Membership]) -> None:
+        pg_ranks = sorted(pg_ranks)
+        self._src_g = pg_ranks[self._src] if self._src is not None else None
+        self._dst_g = pg_ranks[self._dst] if self._dst is not None else None
+
+    @property
+    def src(self) -> int:
+        assert self.type in P2P, "can't get src of non-p2p op"
+        return self._src
+
+    @property
+    def dst(self) -> int:
+        assert self.type in P2P, "can't get dst of non-p2p op"
+        return self._dst
+
+    def __repr__(self) -> str:
+        if self.type in P2P:
+            return (
+                f"{self.type}(s={self._src_g} d={self._dst_g}, sz={self.input_sizes})"
+            )
+        return f"{self.type}(input_sizes={self.input_sizes}, {self.state})"
+
+    def match(self, other) -> bool:  # type: ignore[no-untyped-def]
+        # TODO: I think this can validly not match,
+        # e.g. if one PG was used for p2p ops between only some of the peers?
+        # if self.seq_id != other.seq_id:
+        # return False
+
+        if self.type == "send":
+            return bool(
+                other.type == "recv"
+                and self.src == other.src
+                and self.dst == other.dst
+                and self.input_sizes == other.output_sizes
+            )
+        elif self.type == "recv":
+            return bool(
+                other.type == "send"
+                and self.src == other.src
+                and self.dst == other.dst
+                and self.output_sizes == other.input_sizes
+            )
+        elif self.type in COLLECTIVES:
+            return bool(
+                self.type == other.type and self.input_sizes == other.input_sizes
+            )
+            # TODO(whc) - output sizes don't have to match for e.g. gather; not sure if they ever have to match?
+            # and self.output_sizes == other.output_sizes)
+        elif self.type == "coalesced":
+            return bool(other.type == "coalesced")
+        return False
+
+
+def match_one_event(
+    event_a: Dict[Any, Any],
+    event_b: Dict[Any, Any],
+    memberships: Dict[str, List[Membership]],
+) -> bool:
+    op_a = Op(event_a, memberships)
+    op_b = Op(event_b, memberships)
+    return op_a.match(op_b)
+
+
+def match_coalesced_groups(
+    all_rank_events: Dict[Any, Any],
+    group_size: int,
+    groups: Dict[str, Group],
+    memberships: Dict[str, List[Membership]],
+) -> bool:
+    """
+    all_rank_events: {
+        rank: [
+            (idx, event_dict)
+        ]
+    }
+
+    Note: it is possible for event dicts in a coalesced group to be asymmetric.
+        e.g. the following event lists form a valid coalescing group
+             events0 [send:1]
+             events1 [recv:0, send:2]
+             events2 [recv:1]
+
+    Rule 1: all ops should find a match
+    Rule 2: relative ordering of sends and recvs in one event list can be arbitrary
+        e.g.
+        events1 [recv:0, send:2]  —> okay
+        events1 [send:2, recv:0] —> also okay
+    Rule 3: sends to the same dest or recvs from the same src should be in a consistent order
+        e.g.
+        rank0 [send:1 (100B), send:1 (1000B)]
+        rank1 [recv:0 (1000B), recv:0 (100B)]   —> not okay
+    """
+    all_ops = {
+        rank: [Op(e, memberships) for i, e in all_rank_events[rank]]
+        for rank in all_rank_events
+    }
+
+    def visualize_ops(match: bool) -> None:
+        all_ops = {
+            rank: [Op(e, memberships) for i, e in all_rank_events[rank]]
+            for rank in all_rank_events
+        }
+
+        i = 0
+        row = []
+        progress = True
+        table = []
+        while progress:
+            progress = False
+            for r in all_ops:
+                if len(all_ops[r]) > i:
+                    _, event = all_rank_events[r][i]
+                    row.append(Op(event, memberships))
+                    progress = True
+                else:
+                    row.append(None)  # type: ignore[arg-type]
+            table.append(row)
+            row = []
+            i += 1
+        title = "Match" if match else "MISMATCH"
+        print(f"{title}\n", tabulate(table))  # type: ignore[operator]
+
+    # TODO can't verify seq_id because there might have been valid seq deltas between ranks even within a pg.
+    for op_list in all_ops.values():
+        if not op_list:
+            # print("TODO- not sure if its valid for only some ranks in a PG to participate in a coalesced op?")
+            return False
+        assert op_list[-1].type == "coalesced"
+        op_list.pop(-1)
+
+    while all_ops:
+        first_rank = next(iter(all_ops))
+        my_ops = all_ops[first_rank]
+
+        if len(all_ops[first_rank]) == 0:
+            all_ops.pop(first_rank)
+            continue
+
+        # let's match the first collective! We need to know which ranks are involved, and ensure that this same
+        # collective is also the first one on those ranks within that group
+        op = my_ops[0]
+        match_idx = -1
+        if op.type in P2P:
+            dst_global_rank = sorted(memberships[op.pg_name])[op.dst]
+            peer_ops = all_ops[dst_global_rank]
+            for i, other in enumerate(peer_ops):
+                if op.match(other):
+                    match_idx = i
+                    break
+                elif op.dst == other.src:
+                    # Rule 3
+                    break
+                else:
+                    # Rule 1
+                    continue
+        else:
+            raise NotImplementedError("coalesced collective ops")
+        if match_idx >= 0:
+            my_ops.pop(0)
+            peer_ops.pop(match_idx)
+        else:
+            visualize_ops(False)
+            return False
+
+    visualize_ops(True)
+    return True
+
+
+"""
+Flat DB builder
+"""
+
+
+def build_groups_memberships(
+    pg_config: Any,
+) -> Tuple[List[Group], Dict[Any, Group], List[Membership], Dict[str, Any]]:
+    """
+    pg_config: {
+        global_rank: {
+            (pg_id, desc, ranks)
+        }
+    }
+
+    `pg_id` is a system generated id, but depending on the mode of PG creation it could be a globally incrementing int
+          or a hash of the ranks.  See `_process_group_name` in distributed_c10d.py.
+    `desc` is provided by the user (optionally) and should be 'meaningful' (e.g. TP/PP/DP group)
+    `ranks` is a list of the 'global ranks' that are members of the PG.
+
+    (pg_id, desc, ranks) tuples are appended lazily to the flight buffer when `getNCCLComm` is called on a PG and
+    the `enabled_` flag is true for that PG.
+        - the order of calling (init_process_group, new_group, etc) does not affect the order of the tuples in the list
+
+    Returns: a groups table and a membership table, where each row is a Group or Membership namedtuple
+    """
+    # flat lists for return
+    groups = []
+    memberships = []
+
+    # dicts for faster cross-rank validation
+    _groups = {}
+    _memberships = {}
+    for global_rank in pg_config:
+        for pg_id in pg_config[global_rank]:
+            desc = pg_config[global_rank][pg_id]["desc"]
+            ranks = pg_config[global_rank][pg_id]["ranks"]
+            if isinstance(ranks, str):
+                # TODO Bug in FR data format? ranks is '[0, 1,...]'
+                ranks = ast.literal_eval(ranks)
+
+            if pg_id not in _groups:
+                groups.append(Group(id=pg_id, desc=desc, size=len(ranks)))
+                for rank in ranks:
+                    memberships.append(Membership(group_id=pg_id, global_rank=rank))
+                _groups[pg_id] = groups[-1]
+                _memberships[pg_id] = set(ranks)
+            else:
+                # validation across ranks
+                assert (
+                    _groups[pg_id].desc == desc
+                ), f"mismatch in desc {_groups[pg_id].desc} vs {desc}"
+                assert _memberships[pg_id] == set(
+                    ranks
+                ), f"mismatch in membership {_memberships[pg_id]} vs {set(ranks)}"
+    return groups, _groups, memberships, _memberships
+
+
+def build_nccl_call(
+    entry: Dict[Any, Any],
+    id: int,
+    collective_id: Any,
+    group_id: int,
+    global_rank: Any,
+) -> NCCLCall:
+    return NCCLCall(
+        id=id,
+        collective_id=collective_id,
+        group_id=group_id,  # type: ignore[arg-type]
+        global_rank=global_rank,
+        traceback_id=0,  # type: ignore[arg-type]
+        collective_type=entry["profiling_name"],
+        sizes=entry["input_sizes"],
+    )
+
+
+def find_coalesced_group(
+    pg_name: str, entries: List[Dict[str, Any]]
+) -> List[Tuple[int, Dict[str, Any]]]:
+    """Given a list of entries, if the collective_seq_id of the first entry matches that of subsequent ones,
+    build an return a list of entries terminating in a 'coalesced' op entry all sharing a collective_seq_id
+    TODO: handle p2p_seq_id v/s collective_seq_id separately here.
+    """
+    found = []
+    collective_seq_id = None
+    for i, e in enumerate(entries):
+        if e["process_group"][0] != pg_name:
+            continue
+        elif collective_seq_id is None:
+            collective_seq_id = e["collective_seq_id"]
+            found.append((i, e))
+        elif e["collective_seq_id"] == collective_seq_id:
+            found.append((i, e))
+        else:
+            break
+
+    if len(found) > 1:
+        assert found[-1][1]["profiling_name"] == "nccl:coalesced"
+        return found
+    return []
+
+
+def just_print_entries(
+    all_entries: Dict[int, List[Dict[str, Any]]],
+    _groups: Dict[str, Group],
+    _memberships: Dict[str, List[Membership]],
+) -> None:
+    rows = []
+    ranks = sorted(all_entries.keys())
+    headers = [f"Rank {rank}" for rank in ranks]
+    progress = True
+    while progress:
+        progress = False
+        row = []
+        for rank in ranks:
+            if len(all_entries[rank]) == 0:
+                row.append("")
+            else:
+                entry = all_entries[rank].pop(0)
+                row.append(str(Op(entry, _memberships)))
+                progress = True
+        if progress:
+            rows.append(row)
+
+    print(tabulate(rows, headers=headers))
+
+
+def build_collectives(
+    all_entries: Dict[int, List[Dict[str, Any]]],
+    _groups: Dict[str, Group],
+    _memberships: Dict[str, List[Membership]],
+) -> Tuple[List[Traceback], List[Collective], List[NCCLCall]]:
+    """
+    groups, memberships are the non-flat dicts that are indexable
+    all_entries is a raw dict from the original dumps:
+
+    all_entries: {
+        global_rank: [
+            {
+                record_id: ordered id of the event in the trace buffer
+                pg_id: ProcessGroupNCCL::uid_
+                    *note: `pg_id` corresponds to nothing in groups table
+                process_group: (pg_name, desc)
+                    *note: `pg_name`, `desc` corresponds to `pg_id`, `desc` in groups table
+                collective_seq_id: ordered id for collective operations and coalesced group operations
+                p2p_seq_id: ordered id for point-to-point operations
+                op_id: ordered id including individual ops inside coalescing group
+                profiling_name: descriptive name of the operation
+                'time_created_ns',
+                'input_sizes',
+                'output_sizes',
+                'state',
+                'time_discovered_started_ns',
+                'time_discovered_completed_ns',
+                'retired',
+                'frames',
+            }
+        ]
+    }
+    """
+    tracebacks: List[Traceback] = []
+
+    collectives: List[Collective] = []
+    nccl_calls: List[NCCLCall] = []
+
+    # once we find one mismatch, we stop pairing up collectives since the pairing is possibly incorrect
+    # instead, just record the remaining ops as NCCLCalls
+    mismatch = {_groups[g].id: 0 for g in _groups}
+    MISMATCH_TAIL = 10
+    """
+    - it doesn't matter what order I put collectives/ncclops into their table. we can later on re-sort it by start time
+    - there could be multiple options for the "first" collective to pair up (rank 0,1 might do a bcast while rank 2,3 do a bcast)
+    - within a group, the first collective must be the same on all ranks in the group, then it can be marked as a
+    collective and removed
+    """
+    while all_entries:
+        # we greedily match collectives, starting arbitrarily with the trace from the first rank
+        # later, if we exhaust the first rank, we continue with the next 'first rank'
+        rank_iter = iter(all_entries)
+        first_rank = next(rank_iter)
+        other_ranks = list(rank_iter)
+
+        if len(all_entries[first_rank]) == 0:
+            all_entries.pop(first_rank)
+            continue
+
+        # let's match the first collective! We need to know which ranks are involved, and ensure that this same
+        # collective is also the first one on those ranks within that group
+        entries = all_entries[first_rank]
+        pg_name, desc = entries[0]["process_group"]
+        profiling_name = entries[0]["profiling_name"]
+        expected_ranks = set(_memberships[pg_name])
+        found_ranks = {first_rank}
+        found_idx = {}
+
+        if find_coalesced_group(pg_name, entries):
+            expected_ranks.add(first_rank)
+            done_ranks = set()
+            all_coalesced_entries = {}
+            while expected_ranks:
+                curr = expected_ranks.pop()
+                done_ranks.add(curr)
+                grp = (
+                    find_coalesced_group(pg_name, all_entries[curr])  # type: ignore[index]
+                    if curr in all_entries  # type: ignore[comparison-overlap]
+                    else []
+                )
+                all_coalesced_entries[curr] = grp
+                for index, entry in grp:
+                    op = Op(entry, _memberships)
+                    peer = None
+                    if op.type == "send":
+                        assert op._src_g == curr, (op._src_g, curr)
+                        peer = op._dst_g
+                    elif op.type == "recv":
+                        assert op._dst_g == curr, (op._dst_g, curr)
+                        peer = op._src_g
+                    if peer and peer not in done_ranks:
+                        expected_ranks.add(peer)
+
+            match = match_coalesced_groups(
+                all_coalesced_entries,
+                group_size=_groups[pg_name].size,
+                groups=_groups,
+                memberships=_memberships,
+            )
+
+            if match and mismatch[pg_name] == 0:
+                collectives.append(Collective(id=len(collectives), group_id=pg_name))
+            else:
+                mismatch[pg_name] += 1
+
+            for r in all_coalesced_entries:
+                reversed_calls = []
+                for i, _ in reversed(all_coalesced_entries[r]):
+                    reversed_calls.append(
+                        build_nccl_call(
+                            all_entries[r].pop(i),  # type: ignore[index]
+                            id=len(nccl_calls),
+                            collective_id=collectives[-1].id if match else None,
+                            group_id=pg_name,
+                            global_rank=r,
+                        )
+                    )
+                nccl_calls.extend(reversed(reversed_calls))
+
+        else:
+            for o in expected_ranks.intersection(set(other_ranks)):
+                for i, e in enumerate(all_entries[o]):  # type: ignore[index]
+                    # step over ops from other PGs
+                    if e["process_group"] == (pg_name, desc):
+                        if (
+                            match_one_event(entries[0], e, _memberships)
+                            and mismatch[pg_name] == 0
+                        ):
+                            found_ranks.add(o)
+                            found_idx[o] = i
+                        else:
+                            # we found a mismatch. what do we do with that?
+                            mismatch[pg_name] += 1
+                            print(
+                                f"Mismatched collective on rank {o} for group {pg_name}:{desc} collective {profiling_name}"
+                            )
+                        break
+
+            # at this point there are 3 possibilities
+            # 1. we found a match on all the ranks that are members of the group
+            #  -> we create a Collective and remove the individual entries from their original lists
+            if found_ranks == expected_ranks and mismatch[pg_name] == 0:
+                collectives.append(Collective(id=len(collectives), group_id=pg_name))
+                for r in found_ranks:
+                    i = found_idx[r] if r != first_rank else 0
+                    nccl_calls.append(
+                        build_nccl_call(
+                            all_entries[r].pop(i),  # type: ignore[index]
+                            id=len(nccl_calls),
+                            collective_id=collectives[-1].id,
+                            group_id=pg_name,
+                            global_rank=r,
+                        )
+                    )
+
+            # 2. we found a partial match but some ranks are missing
+            # 3. we found no match
+            #  -> since it's not a complete collective, no entry goes into collectives but we still record a nccl call
+            #     TODO should there be a way to mark 'mismatches'?
+            else:
+                print("appending a non-matching collective")
+                nccl_calls.append(
+                    build_nccl_call(
+                        all_entries[first_rank].pop(0),
+                        id=len(nccl_calls),
+                        collective_id=None,
+                        group_id=pg_name,
+                        global_rank=first_rank,
+                    )
+                )
+
+        if mismatch[pg_name] > MISMATCH_TAIL:
+            print(f"Too many mismatches for process_group {pg_name}:{desc}, aborting")
+            sys.exit(-1)
+
+    return tracebacks, collectives, nccl_calls
+
+
+def check_no_missing_dump_files(
+    entries: Dict[str, Any], memberships: List[Membership]
+) -> None:
+    all_ranks = set()
+    for membership in memberships:
+        all_ranks.add(str(membership.global_rank))
+    dumps_ranks = set(entries.keys())
+    assert (
+        dumps_ranks == all_ranks
+    ), f"Missing dump files from ranks {all_ranks - dumps_ranks}"
+
+
+def check_version(versions: Dict[str, Any]) -> None:
+    for rank, version in versions.items():  # noqa: PERF102
+        major, minor = map(int, version.split("."))
+        # assert major == 2, f"Rank {rank} unsupported version {version}"
+        # assert minor >= 0, f"Rank {rank} unsupported version {version}"
+
+
+def check_trace_from_beginning(entries: Dict[str, Any]) -> bool:
+    for rank in entries:
+        first_record_id = entries[rank][0]["record_id"]
+        # TODO add more sequence information such that analysis can proceed even without a complete buffer
+
+        # assert first_record_id == 0, f"Rank {rank} trace does not start at time 0 (first record is {first_record_id})."
+        if first_record_id != 0:
+            print(
+                f"Rank {rank} trace does not start at time 0 (first record is {first_record_id}."
+            )
+            return False
+    return True
+
+
+def build_db(details: Dict[str, Dict[str, Any]], args: argparse.Namespace) -> Database:
+    # temporary state used for building database
+    entries = {}
+    pg_config = {}
+    version = {}
+    for dump in details.values():
+        rank = dump["rank"]
+        entries[rank] = dump["entries"]
+        version[rank] = dump["version"]
+        pg_config[rank] = dump["pg_config"]
+
+    check_version(version)
+    check_trace_from_beginning(entries)
+
+    # flattened database
+    groups, _groups, memberships, _memberships = build_groups_memberships(pg_config)
+    print("built groups, memberships")
+
+    check_no_missing_dump_files(entries, memberships)
+
+    if args.just_print_entries:
+        just_print_entries(entries, _groups, _memberships)
+        sys.exit(0)
+
+    tracebacks, collectives, nccl_calls = build_collectives(
+        entries, _groups, _memberships
+    )
+    print("built collectives, nccl_calls")
+    if args.verbose:
+        print("Groups\n", tabulate(groups, headers=Group._fields))  # type: ignore[operator]
+        print("Memberships\n", tabulate(memberships, headers=Membership._fields))  # type: ignore[operator]
+        print("Collectives\n", tabulate(collectives, headers=Collective._fields))  # type: ignore[operator]
+        print("NCCLCalls\n", tabulate(nccl_calls, headers=NCCLCall._fields))  # type: ignore[operator]
+    db = Database(
+        tracebacks=tracebacks,
+        collectives=collectives,
+        ncclcalls=nccl_calls,
+        groups=groups,
+        memberships=memberships,
+    )
+    return db
+
+
+def read_dump(prefix: str, filename: str) -> Dict[str, Union[str, int, List[Any]]]:
+    basename = os.path.basename(filename)
+    assert (
+        basename.find(prefix) == 0
+    ), f"args.prefix ({prefix}) must match the beginning of each filename ({basename})"
+    rank = int(basename[len(prefix) :])
+    host_name = f"host_rank{rank}"
+
+    with open(filename, "rb") as infile:
+        dump = pickle.load(infile)
+
+    entries = dump["entries"]
+    version = dump["version"]
+    pg_config = dump["pg_config"]
+
+    return {
+        "host_name": host_name,
+        "rank": rank,
+        "entries": entries,
+        "version": version,
+        "pg_config": pg_config,
+    }
+
+
+def read_dir(prefix: str, folder: str) -> Dict[Any, Any]:  # TODO: fix types
+    import gc
+    import time
+
+    gc.disable()
+    details = {}
+    t0 = time.time()
+    for root, _, files in os.walk(folder):
+        for f in files:
+            ta = time.time()
+            details[f] = read_dump(prefix, os.path.join(root, f))
+            tb = time.time()
+            # print(f"read file {f} in {tb - ta}s")
+    print(f"loaded {len(files)} files in {tb - t0}s")
+    return details
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument("-d", "--dir", help="Directory with flight recorder dumps")
+    parser.add_argument("-o", "--output", default=None)
+    parser.add_argument(
+        "-p",
+        "--prefix",
+        help="prefix to strip such that rank can be extracted",
+        default="rank_",
+    )
+    parser.add_argument("-j", "--just_print_entries", action="store_true")
+    parser.add_argument("-v", "--verbose", action="store_true")
+    args = parser.parse_args()
+
+    details = read_dir(args.prefix, args.dir)
+    db = build_db(details, args)
+    if args.output:
+        with open(args.output, "wb") as f:
+            pickle.dump((types, db), f)
+
+
+if __name__ == "__main__":
+    main()