|  | import argparse | 
|  | import json | 
|  | from os import path | 
|  |  | 
|  | import torch | 
|  |  | 
|  | # Import all utils so that getattr below can find them | 
|  |  | 
|  | all_submod_list = [ | 
|  | "", | 
|  | "nn", | 
|  | "nn.functional", | 
|  | "nn.init", | 
|  | "optim", | 
|  | "autograd", | 
|  | "cuda", | 
|  | "sparse", | 
|  | "distributions", | 
|  | "fft", | 
|  | "linalg", | 
|  | "jit", | 
|  | "distributed", | 
|  | "futures", | 
|  | "onnx", | 
|  | "random", | 
|  | "utils.bottleneck", | 
|  | "utils.checkpoint", | 
|  | "utils.data", | 
|  | "utils.model_zoo", | 
|  | ] | 
|  |  | 
|  |  | 
|  | def get_content(submod): | 
|  | mod = torch | 
|  | if submod: | 
|  | submod = submod.split(".") | 
|  | for name in submod: | 
|  | mod = getattr(mod, name) | 
|  | content = dir(mod) | 
|  | return content | 
|  |  | 
|  |  | 
|  | def namespace_filter(data): | 
|  | out = {d for d in data if d[0] != "_"} | 
|  | return out | 
|  |  | 
|  |  | 
|  | def run(args, submod): | 
|  | print(f"## Processing torch.{submod}") | 
|  | prev_filename = f"prev_data_{submod}.json" | 
|  | new_filename = f"new_data_{submod}.json" | 
|  |  | 
|  | if args.prev_version: | 
|  | content = get_content(submod) | 
|  | with open(prev_filename, "w") as f: | 
|  | json.dump(content, f) | 
|  | print("Data saved for previous version.") | 
|  | elif args.new_version: | 
|  | content = get_content(submod) | 
|  | with open(new_filename, "w") as f: | 
|  | json.dump(content, f) | 
|  | print("Data saved for new version.") | 
|  | else: | 
|  | assert args.compare | 
|  | if not path.exists(prev_filename): | 
|  | raise RuntimeError("Previous version data not collected") | 
|  |  | 
|  | if not path.exists(new_filename): | 
|  | raise RuntimeError("New version data not collected") | 
|  |  | 
|  | with open(prev_filename) as f: | 
|  | prev_content = set(json.load(f)) | 
|  |  | 
|  | with open(new_filename) as f: | 
|  | new_content = set(json.load(f)) | 
|  |  | 
|  | if not args.show_all: | 
|  | prev_content = namespace_filter(prev_content) | 
|  | new_content = namespace_filter(new_content) | 
|  |  | 
|  | if new_content == prev_content: | 
|  | print("Nothing changed.") | 
|  | print("") | 
|  | else: | 
|  | print("Things that were added:") | 
|  | print(new_content - prev_content) | 
|  | print("") | 
|  |  | 
|  | print("Things that were removed:") | 
|  | print(prev_content - new_content) | 
|  | print("") | 
|  |  | 
|  |  | 
|  | def main(): | 
|  | parser = argparse.ArgumentParser( | 
|  | description="Tool to check namespace content changes" | 
|  | ) | 
|  |  | 
|  | group = parser.add_mutually_exclusive_group(required=True) | 
|  | group.add_argument("--prev-version", action="store_true") | 
|  | group.add_argument("--new-version", action="store_true") | 
|  | group.add_argument("--compare", action="store_true") | 
|  |  | 
|  | group = parser.add_mutually_exclusive_group() | 
|  | group.add_argument("--submod", default="", help="part of the submodule to check") | 
|  | group.add_argument( | 
|  | "--all-submod", | 
|  | action="store_true", | 
|  | help="collects data for all main submodules", | 
|  | ) | 
|  |  | 
|  | parser.add_argument( | 
|  | "--show-all", | 
|  | action="store_true", | 
|  | help="show all the diff, not just public APIs", | 
|  | ) | 
|  |  | 
|  | args = parser.parse_args() | 
|  |  | 
|  | if args.all_submod: | 
|  | submods = all_submod_list | 
|  | else: | 
|  | submods = [args.submod] | 
|  |  | 
|  | for mod in submods: | 
|  | run(args, mod) | 
|  |  | 
|  |  | 
|  | if __name__ == "__main__": | 
|  | main() |