| # mypy: allow-untyped-defs |
| import logging |
| import warnings |
| from typing import Any, Dict, Iterable, Optional, Tuple |
| |
| import torch |
| import torch.export |
| import torch.export._trace |
| from torch._utils_internal import log_export_usage |
| |
| log = logging.getLogger(__name__) |
| |
| __all__ = ["report_exportability"] |
| |
| |
| def _generate_inputs_for_submodules( |
| model: torch.nn.Module, |
| target_submodules: Iterable[str], |
| args: Tuple[Any, ...], |
| kwargs: Optional[Dict[str, Any]] = None, |
| ) -> Dict[str, Tuple[Any, Any]]: |
| """ |
| Generate inputs for targeting submdoules in the given model. Note that if two submodules refer to the same obj, this |
| function doesn't work. |
| |
| Args: |
| model: root model. |
| inputs: inputs to the root model. |
| target_submodules: submodules that we want to generate inputs for. |
| |
| Returns: |
| A dict that maps from submodule name to its inputs. |
| """ |
| kwargs = kwargs or {} |
| |
| handles = [] |
| results = {} |
| submodule_to_names = {mod: name for name, mod in model.named_modules()} |
| |
| def pre_forward(module, module_args, module_kwargs): |
| results[submodule_to_names[module]] = (module_args, module_kwargs) |
| |
| try: |
| for name, mod in model.named_modules(): |
| if name in target_submodules: |
| handles.append( |
| mod.register_forward_pre_hook(pre_forward, with_kwargs=True) |
| ) |
| model(*args, **kwargs) |
| except Exception as e: |
| warnings.warn( |
| f"Failed to generate submodule inputs because of the following error:\n{e}" |
| ) |
| finally: |
| for h in handles: |
| h.remove() |
| return results |
| |
| |
| def report_exportability( |
| mod: torch.nn.Module, |
| args: Tuple[Any, ...], |
| kwargs: Optional[Dict[str, Any]] = None, |
| *, |
| strict: bool = True, |
| pre_dispatch: bool = False, |
| ) -> Dict[str, Optional[Exception]]: |
| """ |
| Report exportability issues for a module in one-shot. |
| |
| Args: |
| mod: root module. |
| args: args to the root module. |
| kwargs: kwargs to the root module. |
| Returns: |
| A dict that maps from submodule name to the exception that was raised when trying to export it. |
| `None` means the module is exportable without issue. |
| Sample output: |
| { |
| '': UnsupportedOperatorException(func=<OpOverload(op='testlib.op_missing_meta', overload='default')>), |
| 'submod_1': UnsupportedOperatorException(func=<OpOverload(op='testlib.op_missing_meta', overload='default')>), |
| 'submod_2': None |
| } |
| """ |
| |
| log_export_usage(event="export.report_exportability") |
| |
| kwargs = kwargs or {} |
| |
| all_submod_names = [name for name, _ in mod.named_modules() if name != ""] |
| submod_inputs = _generate_inputs_for_submodules(mod, all_submod_names, args, kwargs) |
| |
| report: Dict[str, Optional[Exception]] = {} |
| |
| def try_export(module, module_name, args, kwargs): |
| nonlocal submod_inputs, report, strict, pre_dispatch |
| |
| if args is not None or kwargs is not None: |
| try: |
| torch.export._trace._export( |
| module, |
| args, |
| kwargs, |
| strict=strict, |
| pre_dispatch=pre_dispatch, |
| ) |
| report[module_name] = None |
| log.info("Successfully exported `%s`", module_name) |
| return |
| except Exception as e: |
| short_msg = repr(e).split("\n")[0] |
| log.warning( |
| "Failed exporting `%s` with exception: %s", module_name, short_msg |
| ) |
| report[module_name] = e |
| |
| for name, submod in module.named_children(): |
| sub_module_name = name if module_name == "" else f"{module_name}.{name}" |
| |
| submod_args, submod_kwargs = submod_inputs.get( |
| sub_module_name, (None, None) |
| ) |
| |
| try_export(submod, sub_module_name, submod_args, submod_kwargs) |
| |
| return |
| |
| try_export(mod, "", args, kwargs) |
| |
| unique_issues = set() |
| for exception in report.values(): |
| if exception is not None: |
| key = repr(exception).split("\\n")[0] |
| unique_issues.add(key) |
| |
| log.warning("Found %d export issues:", len(unique_issues)) |
| for issue in unique_issues: |
| log.warning(issue) |
| |
| return report |