|  |  | 
|  | import ast | 
|  | import inspect | 
|  | import sys | 
|  | import textwrap | 
|  | import torch | 
|  | import warnings | 
|  |  | 
class AttributeTypeIsSupportedChecker(ast.NodeVisitor):
    """
    Check the ``__init__`` method of a given ``nn.Module`` to ensure
    that all instance-level attributes can be properly initialized.

    Specifically, we do type inference based on attribute values...even
    if the attribute in question has already been typed using
    Python3-style annotations or ``torch.jit.annotate``. This means that
    setting an instance-level attribute to ``[]`` (for ``List``),
    ``{}`` (for ``Dict``), or ``None`` (for ``Optional``) isn't enough
    information for us to properly initialize that attribute.

    An object of this class can walk a given ``nn.Module``'s AST and
    determine if it meets our requirements or not. Violations are
    reported through ``warnings.warn``; nothing is raised.

    Known limitations
    1. We can only check the AST nodes for certain constructs; we can't
    ``eval`` arbitrary expressions. This means that function calls,
    class instantiations, and complex expressions that resolve to one of
    the "empty" values specified above will NOT be flagged as
    problematic.
    2. We match on string literals, so if the user decides to use a
    non-standard import (e.g. `from typing import List as foo`), we
    won't catch it.

    Example:

    .. code-block:: python

        class M(torch.nn.Module):
            def fn(self):
                return []

            def __init__(self):
                super().__init__()
                self.x: List[int] = self.fn()

            def forward(self, x: List[int]):
                self.x = x
                return 1

    The above code will pass the ``AttributeTypeIsSupportedChecker``
    check since we have a function call in ``__init__``. However,
    it will still fail later with the ``RuntimeError`` "Tried to set
    nonexistent attribute: x. Did you forget to initialize it in
    __init__()?".

    Usage: create a checker and call ``check(nn_module)``, where
    ``nn_module`` is the instance of ``torch.nn.Module`` whose
    ``__init__`` method we wish to check.
    """

    # TODO @ansley: add `Union` once landed
    #
    # NB: Even though `Tuple` is a "container", we don't want to
    # check for it here. `Tuple` functions as a type with an
    # "infinite" number of subtypes, in the sense that you can have
    # `Tuple[()]`, `Tuple[T1]`, `Tuple[T2]`, `Tuple[T1, T2]`,
    # `Tuple[T2, T1]` and so on, and none of these subtypes can be
    # used in place of the other. Therefore, assigning an empty
    # tuple in `__init__` CORRECTLY means that that variable
    # cannot be reassigned later to a non-empty tuple. Same
    # deal with `NamedTuple`
    _CONTAINER_TYPES = frozenset({"List", "Dict", "Optional"})

    # Single source of truth for the warning emitted by both
    # `visit_AnnAssign` and `visit_Call`.
    _UNSUPPORTED_MSG = ("The TorchScript type system doesn't support "
                        "instance-level annotations on empty non-base "
                        "types in `__init__`. Instead, either 1) use a "
                        "type annotation in the class body, or 2) wrap "
                        "the type in `torch.jit.Attribute`.")

    def __init__(self) -> None:
        super().__init__()
        # Give the visitor state sane defaults so that calling `visit`
        # directly (without going through `check`) doesn't fail on a
        # missing attribute.
        self.using_deprecated_ast: bool = sys.version_info < (3, 8)
        self.class_level_annotations: list = []
        self.visiting_class_level_ann: bool = False

    def check(self, nn_module: torch.nn.Module) -> None:
        """
        Parse the source of ``type(nn_module).__init__`` and walk it,
        warning about every unsupported empty-container attribute.

        Args:
            nn_module: the module instance whose ``__init__`` is checked.

        Raises:
            Whatever ``inspect.getsource`` or ``ast.parse`` raise when
            the source is unavailable or cannot be parsed.
        """
        # Check if we have a Python version <3.8
        self.using_deprecated_ast = sys.version_info < (3, 8)

        init_source = inspect.getsource(nn_module.__class__.__init__)

        # Ignore comments no matter the indentation (presumably so that
        # a comment indented less than the code can't stop
        # `textwrap.dedent` from removing the method's indentation).
        # `# type:` comments are kept.
        def is_useless_comment(line: str) -> bool:
            stripped = line.strip()
            return (stripped.startswith("#")
                    and not stripped.startswith("# type:"))

        init_source = "\n".join(
            line for line in init_source.split("\n")
            if not is_useless_comment(line)
        )

        # This AST only contains the `__init__` method of the nn.Module
        init_ast = ast.parse(textwrap.dedent(init_source))

        # Get items annotated in the class body. Fall back to "no
        # annotations" instead of raising when the attribute is absent.
        self.class_level_annotations = list(
            getattr(nn_module, "__annotations__", {})
        )

        # Flag for later
        self.visiting_class_level_ann = False

        self.visit(init_ast)

    def _container_name(self, annotation):
        """
        Return ``"List"``, ``"Dict"`` or ``"Optional"`` if ``annotation``
        is a subscript of one of those bare names (e.g. ``List[int]``),
        otherwise ``None``.
        """
        if not isinstance(annotation, ast.Subscript):
            return None
        # Only a bare `ast.Name` carries `id`; `typing.List[int]` and
        # other shapes are deliberately not matched (see limitation 2).
        name = getattr(annotation.value, "id", None)
        return name if name in self._CONTAINER_TYPES else None

    @staticmethod
    def _is_annotate_call(func) -> bool:
        """
        Return True if ``func`` (the ``func`` field of a Call node) is
        spelled ``torch.jit.annotate`` or ``jit.annotate``.
        """
        if not isinstance(func, ast.Attribute) or func.attr != "annotate":
            return False
        owner = func.value
        if isinstance(owner, ast.Name):
            # `jit.annotate(...)`, e.g. after `from torch import jit`
            return owner.id == "jit"
        return (isinstance(owner, ast.Attribute)
                and owner.attr == "jit"
                and isinstance(owner.value, ast.Name)
                and owner.value.id == "torch")

    def _is_empty_container(self, node: ast.AST, ann_type: str) -> bool:
        """
        Return True if ``node`` is the "empty" literal for ``ann_type``:
        ``[]`` for ``"List"``, ``{}`` for ``"Dict"``, ``None`` for
        ``"Optional"``.

        For any other ``ann_type`` this returns True (kept from the
        original implementation; callers only pass the three names
        above).
        """
        if ann_type == "List":
            # Assigning `[]` to a `List` type gives you a Node where
            # value=List(elts=[], ctx=Load())
            return isinstance(node, ast.List) and not node.elts
        if ann_type == "Dict":
            # Assigning `{}` to a `Dict` type gives you a Node where
            # value=Dict(keys=[], values=[])
            return isinstance(node, ast.Dict) and not node.keys
        if ann_type == "Optional":
            # Assigning `None` to an `Optional` type gives you a
            # Node where value=Constant(value=None, kind=None)
            # or, in Python <3.8, value=NameConstant(value=None).
            # The conditional expression only touches
            # `ast.NameConstant` on interpreters that still need it.
            const_type = (ast.NameConstant if self.using_deprecated_ast
                          else ast.Constant)
            # Compare against `None` explicitly: falsy constants such
            # as `0`, `""` or `False` are real values, not "empty".
            return (isinstance(node, const_type)
                    and node.value is None)  # type: ignore[attr-defined]
        return True

    def visit_Assign(self, node):
        """
        If we're visiting a Call Node (the right-hand side of an
        assignment statement), we won't be able to check the variable
        that we're assigning to (the left-hand side of an assignment).
        Because of this, we need to store this state in visit_Assign.
        (Luckily, we only have to do this if we're assigning to a Call
        Node, i.e. ``torch.jit.annotate``. If we're using normal Python
        annotations, we'll be visiting an AnnAssign Node, which has its
        target built in.)
        """
        if isinstance(node.value, ast.Call):
            target = node.targets[0]
            # Only attribute targets (`obj.name = call(...)`) are of
            # interest; a call assigned to anything else (e.g. a local
            # variable) is skipped entirely, children included.
            if not isinstance(target, ast.Attribute):
                return
            if target.attr in self.class_level_annotations:
                self.visiting_class_level_ann = True
        try:
            self.generic_visit(node)
        finally:
            # Always clear the flag so an error during traversal can't
            # leave it set for subsequent statements.
            self.visiting_class_level_ann = False

    def visit_AnnAssign(self, node):
        """
        Visit an AnnAssign node in an ``nn.Module``'s ``__init__``
        method and see if it conforms to our attribute annotation rules.
        """
        # If we have a local variable (anything that isn't literally
        # `self.<name>`), there is nothing to check
        target = node.target
        if not (isinstance(target, ast.Attribute)
                and isinstance(target.value, ast.Name)
                and target.value.id == "self"):
            return

        # If we have an attribute that's already been annotated at the
        # class level
        if target.attr in self.class_level_annotations:
            return

        # If we're not evaluating one of the specified problem types
        # (base types such as `str` or `int` are plain names, not
        # subscripts, so they fall out here too)
        ann_type = self._container_name(node.annotation)
        if ann_type is None:
            return

        # Check if the assigned variable is empty
        if self._is_empty_container(node.value, ann_type):
            warnings.warn(self._UNSUPPORTED_MSG)

    def visit_Call(self, node):
        """
        Visit a Call node in an ``nn.Module``'s ``__init__``
        method and determine if it's ``torch.jit.annotate``. If so,
        see if it conforms to our attribute annotation rules.
        """
        # If we have an attribute that's already been annotated at the
        # class level
        if self.visiting_class_level_ann:
            return

        # Always descend so that calls nested in the arguments are
        # checked as well.
        self.generic_visit(node)

        # If this isn't a call to `torch.jit.annotate` (or
        # `jit.annotate`), this node itself has nothing to check
        if not self._is_annotate_call(node.func):
            return

        # A Call Node for `torch.jit.annotate` should have an `args`
        # list of length 2 where args[0] represents the annotation and
        # args[1] represents the actual value
        if len(node.args) != 2:
            return

        ann_type = self._container_name(node.args[0])
        if ann_type is None:
            return

        # Check if the assigned variable is empty
        if self._is_empty_container(node.args[1], ann_type):
            warnings.warn(self._UNSUPPORTED_MSG)