annotate torch.autograd.* modules (#45004)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/44638
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45004
Reviewed By: VitalyFedyunin
Differential Revision: D24113562
Pulled By: ezyang
fbshipit-source-id: a85018b7e08b2fe6cf2bc14a217eb418cb2b9de4
diff --git a/mypy.ini b/mypy.ini
index ea7bdb1..af39fd6 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -180,27 +180,6 @@
[mypy-torch.utils.hipify.hipify_python]
ignore_errors = True
-[mypy-torch.autograd._functions.tensor]
-ignore_errors = True
-
-[mypy-torch.autograd.function]
-ignore_errors = True
-
-[mypy-torch.autograd.functional]
-ignore_errors = True
-
-[mypy-torch.autograd.profiler]
-ignore_errors = True
-
-[mypy-torch.autograd.gradcheck]
-ignore_errors = True
-
-[mypy-torch.autograd.anomaly_mode]
-ignore_errors = True
-
-[mypy-torch.autograd.variable]
-ignore_errors = True
-
[mypy-torch.nn.quantized.modules.batchnorm]
ignore_errors = True
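With the blanket `ignore_errors = True` sections gone, any remaining false positives in these modules must be suppressed line by line, which is what the rest of this diff does. A minimal sketch of the inline style (hypothetical snippet, not from the PR):

```python
from typing import Optional

def head(xs: list) -> Optional[int]:
    return xs[0] if xs else None

# Under `ignore_errors = True` every error in the module is hidden; an
# inline comment suppresses only this one expression and keeps the rest
# of the module type-checked:
value = head([1, 2, 3]) + 1  # type: ignore
print(value)
```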
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index f1e96e3..2ad2f64 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -93,6 +93,7 @@
# Defined in torch/csrc/utils/tensor_layouts.cpp
strided : layout = ...
sparse_coo : layout = ...
+_mkldnn : layout = ...
# Defined in torch/csrc/MemoryFormat.cpp
class memory_format: ...
@@ -268,6 +269,10 @@
class Graph:
...
+# Defined in torch/csrc/jit/ir/ir.h
+class Value:
+ ...
+
# Defined in torch/aten/src/ATen/core/function_schema.h
class FunctionSchema:
...
@@ -389,6 +394,7 @@
def _vmapmode_increment_nesting() -> _int: ... # THPModule_vmapmode_increment_nesting
def _vmapmode_decrement_nesting() -> _int: ... # THPModule_vmapmode_decrement_nesting
def _log_api_usage_once(str) -> None: ... # LogAPIUsageOnceFromPython
+def _demangle(str) -> str: ... # c10::demangle
# Defined in `valgrind.h` and `callgrind.h` respectively.
def valgrind_supported_platform() -> _bool: ... # NVALGRIND
@@ -497,6 +503,10 @@
# TODO: where
${legacy_class_hints}
+# Defined in torch/csrc/autograd/python_engine.cpp
+class _ImperativeEngine:
+ ...
+
# Defined in torch/csrc/autograd/python_variable.cpp
class _TensorBase(object):
requires_grad: _bool
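Of the new stubs, `_demangle` maps directly onto `c10::demangle`; the profiler uses it to clean up kernel names. A quick sanity check of the declared signature (assumes a platform where demangling is supported; elsewhere it should return the input unchanged):

```python
import torch

# Per the new stub, torch._C._demangle(str) -> str wraps c10::demangle;
# it turns an Itanium-mangled C++ symbol back into a readable name.
print(torch._C._demangle("_ZN3c106IValueE"))  # e.g. "c10::IValue"
```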
diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi
index 653b705..a154fb1 100644
--- a/torch/_C/_autograd.pyi
+++ b/torch/_C/_autograd.pyi
@@ -11,10 +11,14 @@
class ProfilerConfig:
- def __init__(self, state: ProfilerState, report_input_shapes: bool, profile_memory: bool) -> None: ...
+ def __init__(
+ self, state: ProfilerState,
+ report_input_shapes: bool,
+ profile_memory: bool,
+ with_stack: bool
+ ) -> None: ...
...
-
class ProfilerEvent:
def cpu_elapsed_us(self, other: ProfilerEvent) -> float: ...
def cpu_memory_usage(self) -> int: ...
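The added `with_stack` argument tracks the stack-recording support the profiler gained around this time. Through the public wrapper the flag looks like this (a usage sketch, assuming a build where `profile(with_stack=...)` is available):

```python
import torch

with torch.autograd.profiler.profile(with_stack=True) as prof:
    torch.ones(16).sum()

# group_by_stack_n matches the key_averages() signature annotated later
# in this diff (torch/autograd/profiler.py)
print(prof.key_averages(group_by_stack_n=2).table(sort_by="self_cpu_time_total"))
```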
diff --git a/torch/_C/_functions.pyi b/torch/_C/_functions.pyi
new file mode 100644
index 0000000..4ad76e4
--- /dev/null
+++ b/torch/_C/_functions.pyi
@@ -0,0 +1,12 @@
+from torch import Tensor
+from typing import AnyStr, List
+
+class UndefinedGrad:
+ def __init__(self) -> None: ...
+ def __call__(self, inputs: List[Tensor]) -> List[Tensor]: ...
+ ...
+
+class DelayedError:
+ def __init__(self, msg: AnyStr, num_inputs: int) -> None: ...
+ def __call__(self, inputs: List[Tensor]) -> List[Tensor]: ...
+ ...
\ No newline at end of file
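Both classes are thin bindings over internal autograd nodes. The only in-tree consumer of `DelayedError` is `@once_differentiable` in `torch/autograd/function.py` (next file), which can be exercised through a regular custom Function; a self-contained sketch:

```python
import torch
from torch.autograd.function import Function, once_differentiable

class ClampMin0(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.clamp(min=0)

    @staticmethod
    @once_differentiable  # backward runs under no_grad; differentiating it
    def backward(ctx, grad_out):  # again hits the DelayedError message
        (x,) = ctx.saved_tensors
        return grad_out * (x > 0).to(grad_out.dtype)

x = torch.randn(4, requires_grad=True)
ClampMin0.apply(x).sum().backward()
print(x.grad)
```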
diff --git a/torch/autograd/function.py b/torch/autograd/function.py
index 6714444..0d546ce 100644
--- a/torch/autograd/function.py
+++ b/torch/autograd/function.py
@@ -1,11 +1,12 @@
import torch
import torch._C as _C
+from torch._C import _functions
import torch.utils.hooks as hooks
from torch._six import with_metaclass
import functools
import warnings
from collections import OrderedDict
-from typing import Any
+from typing import Any, List, Optional
class _ContextMethodMixin(object):
@@ -84,7 +85,8 @@
_is_legacy = False
def apply(self, *args):
- return self._forward_cls.backward(self, *args)
+ # _forward_cls is defined by derived class
+ return self._forward_cls.backward(self, *args) # type: ignore
class FunctionMeta(type):
@@ -115,8 +117,8 @@
return super(FunctionMeta, cls).__init__(name, bases, attrs)
-
-class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixin, _HookMixin)):
+# mypy doesn't understand `with_metaclass` from torch._six
+class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixin, _HookMixin)): # type: ignore
r"""Records operation history and defines formulas for differentiating ops.
See the Note on extending the autograd engine for more details on how to use
@@ -227,7 +229,7 @@
if not isinstance(outputs, tuple):
outputs = (outputs,)
- err_fn = torch._C._functions.DelayedError(
+ err_fn = _functions.DelayedError(
b"trying to differentiate twice a function that was marked"
b"with @once_differentiable", len(outputs))
@@ -330,7 +332,7 @@
# unflatten a list or tuple input into a nested list/tuple structure
# specified by proto
def unflatten_helper(input, proto):
- res = []
+ res: List[Optional[torch.Tensor]] = []
if hasattr(proto, "_jit_wrap"):
return proto._jit_wrap(input)
if not isinstance(proto, (list, tuple)):
@@ -379,16 +381,16 @@
del self._to_save_nested
return result
- def backward(self, *gradients: Any) -> Any:
+ def backward(self, *gradients: Any) -> Any: # type: ignore
nested_gradients = _unflatten(gradients, self._nested_output)
- result = self.backward_extended(*nested_gradients)
+ result = self.backward_extended(*nested_gradients) # type: ignore
return tuple(_iter_None_tensors(result))
__call__ = _do_forward
- def forward(self, *args: Any) -> Any:
+ def forward(self, *args: Any) -> Any: # type: ignore
nested_tensors = _map_tensor_data(self._nested_input)
- result = self.forward_extended(*nested_tensors)
+ result = self.forward_extended(*nested_tensors) # type: ignore
del self._nested_input
self._nested_output = result
return tuple(_iter_tensors(result))
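For contrast with the legacy `NestedIOFunction` methods that need `# type: ignore` above, a modern `Function` subclass keeps `forward`/`backward` as staticmethods and type-checks cleanly (standard example, not from this PR):

```python
import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

x = torch.tensor([3.0], requires_grad=True)
Square.apply(x).sum().backward()
print(x.grad)  # tensor([6.])
```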
diff --git a/torch/autograd/functional.py b/torch/autograd/functional.py
index 58e780c..70961ce 100644
--- a/torch/autograd/functional.py
+++ b/torch/autograd/functional.py
@@ -1,4 +1,5 @@
import torch
+from typing import Tuple, List
# Utility functions
@@ -131,8 +132,8 @@
assert isinstance(grad_outputs, tuple)
assert len(outputs) == len(grad_outputs)
- new_outputs = tuple()
- new_grad_outputs = tuple()
+ new_outputs: Tuple[torch.Tensor, ...] = tuple()
+ new_grad_outputs: Tuple[torch.Tensor, ...] = tuple()
for out, grad_out in zip(outputs, grad_outputs):
if out is not None and out.requires_grad:
new_outputs += (out,)
@@ -153,7 +154,7 @@
if stage not in ["back", "back_trick", "double_back", "double_back_trick"]:
raise RuntimeError("Invalid stage argument '{}' to _fill_in_zeros".format(stage))
- res = tuple()
+ res: Tuple[torch.Tensor, ...] = tuple()
for i, grads_i in enumerate(grads):
if grads_i is None:
if strict:
@@ -427,10 +428,11 @@
"jacobian")
_check_requires_grad(outputs, "outputs", strict=strict)
- jacobian = tuple()
+ jacobian: Tuple[torch.Tensor, ...] = tuple()
for i, out in enumerate(outputs):
- jac_i = tuple([] for _ in range(len(inputs)))
+ # mypy complains that expression and variable have different types due to the empty list
+ jac_i: Tuple[List[torch.Tensor]] = tuple([] for _ in range(len(inputs))) # type: ignore
for j in range(out.nelement()):
vj = _autograd_grad((out.reshape(-1)[j],), inputs,
retain_graph=True, create_graph=create_graph)
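For context, the tuples annotated above are accumulators inside `torch.autograd.functional.jacobian`; the public behaviour they feed:

```python
import torch
from torch.autograd.functional import jacobian

def f(x):
    return x ** 2

# d(x_i^2)/dx_j is 2 * x_i on the diagonal and 0 elsewhere
print(jacobian(f, torch.tensor([1.0, 2.0])))
# tensor([[2., 0.],
#         [0., 4.]])
```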
diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py
index b2bea45..531bcc6 100644
--- a/torch/autograd/gradcheck.py
+++ b/torch/autograd/gradcheck.py
@@ -5,7 +5,7 @@
from torch.overrides import is_tensor_like
from itertools import product
import warnings
-from typing import Callable, Union, Optional
+from typing import Callable, Union, Optional, Iterable, List
def zero_gradients(x):
if isinstance(x, torch.Tensor):
@@ -29,15 +29,16 @@
lambda x: x is not None, (make_jacobian(elem, num_out) for elem in input)))
if not jacobians:
return None
- return type(input)(jacobians)
+ return type(input)(jacobians) # type: ignore
else:
return None
-def iter_tensors(x, only_requiring_grad=False):
+def iter_tensors(x: Union[torch.Tensor, Iterable[torch.Tensor]], only_requiring_grad: bool = False) -> Iterable[torch.Tensor]:
if is_tensor_like(x):
- if x.requires_grad or not only_requiring_grad:
- yield x
+ # mypy doesn't narrow type of `x` to torch.Tensor
+ if x.requires_grad or not only_requiring_grad: # type: ignore
+ yield x # type: ignore
elif isinstance(x, container_abcs.Iterable) and not isinstance(x, str):
for elem in x:
for result in iter_tensors(elem, only_requiring_grad):
@@ -137,7 +138,7 @@
indices = x_indices[i].tolist() + list(x_idx)
d_idx = sum(indices[k] * x_stride[k] for k in range(len(x_size)))
update_jacobians(x_value, x_idx, d_tensor, d_idx)
- elif x_tensor.layout == torch._mkldnn:
+ elif x_tensor.layout == torch._mkldnn: # type: ignore
# Use .data here to get around the version check
x_tensor = x_tensor.data
if len(input) != 1:
@@ -163,7 +164,7 @@
if output.is_sparse:
raise ValueError('Sparse output is not supported at gradcheck yet. '
'Please call to_dense() on the output of fn for gradcheck.')
- if output.layout == torch._mkldnn:
+ if output.layout == torch._mkldnn: # type: ignore
raise ValueError('MKLDNN output is not supported at gradcheck yet. '
'Please call to_dense() on the output of fn for gradcheck.')
diff_input_list = list(iter_tensors(input, True))
@@ -303,13 +304,13 @@
content = inp._values() if inp.is_sparse else inp
# TODO: To cover more problematic cases, replace stride = 0 check with
# "any overlap in memory" once we have a proper function to check it.
- if content.layout is not torch._mkldnn and \
- not all(st > 0 or sz <= 1 for st, sz in zip(content.stride(), content.size())):
- raise RuntimeError(
- 'The {}th input has a dimension with stride 0. gradcheck only '
- 'supports inputs that are non-overlapping to be able to '
- 'compute the numerical gradients correctly. You should call '
- '.contiguous on the input before passing it to gradcheck.')
+ if content.layout is not torch._mkldnn: # type: ignore
+ if not all(st > 0 or sz <= 1 for st, sz in zip(content.stride(), content.size())):
+ raise RuntimeError(
+ 'The {}th input has a dimension with stride 0. gradcheck only '
+ 'supports inputs that are non-overlapping to be able to '
+ 'compute the numerical gradients correctly. You should call '
+ '.contiguous on the input before passing it to gradcheck.')
any_input_requiring_grad = True
inp.retain_grad()
if not any_input_requiring_grad:
@@ -403,30 +404,30 @@
# check if the backward multiplies by grad_output
output = _differentiable_outputs(func(*tupled_inputs))
if any([o.requires_grad for o in output]):
- diff_input_list = list(iter_tensors(tupled_inputs, True))
+ diff_input_list: List[torch.Tensor] = list(iter_tensors(tupled_inputs, True))
if not diff_input_list:
raise RuntimeError("no Tensors requiring grad found in input")
grads_input = torch.autograd.grad(output, diff_input_list,
[torch.zeros_like(o, memory_format=torch.legacy_contiguous_format) for o in output],
allow_unused=True)
- for gi, i in zip(grads_input, diff_input_list):
+ for gi, di in zip(grads_input, diff_input_list):
if gi is None:
continue
if isinstance(gi, torch.Tensor) and gi.layout != torch.strided:
- if gi.layout != i.layout:
- return fail_test('grad is incorrect layout (' + str(gi.layout) + ' is not ' + str(i.layout) + ')')
+ if gi.layout != di.layout:
+ return fail_test('grad is incorrect layout (' + str(gi.layout) + ' is not ' + str(di.layout) + ')')
if gi.layout == torch.sparse_coo:
- if gi.sparse_dim() != i.sparse_dim():
+ if gi.sparse_dim() != di.sparse_dim():
return fail_test('grad is sparse tensor, but has incorrect sparse_dim')
- if gi.dense_dim() != i.dense_dim():
+ if gi.dense_dim() != di.dense_dim():
return fail_test('grad is sparse tensor, but has incorrect dense_dim')
gi = gi.to_dense()
- i = i.to_dense()
+ di = di.to_dense()
if not gi.eq(0).all():
return fail_test('backward not multiplied by grad_output')
- if gi.dtype != i.dtype or gi.device != i.device or gi.is_sparse != i.is_sparse:
+ if gi.dtype != di.dtype or gi.device != di.device or gi.is_sparse != di.is_sparse:
return fail_test("grad is incorrect type")
- if gi.size() != i.size():
+ if gi.size() != di.size():
return fail_test('grad is incorrect size')
if check_undefined_grad:
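The layout branches annotated in this file guard `gradcheck`'s numerical path, which perturbs strided double-precision inputs:

```python
import torch

# gradcheck compares analytical gradients against finite differences;
# double precision is recommended for the numerical side.
x = torch.randn(3, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(torch.sin, (x,))
```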
diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py
index 8d33be0..eba7368 100644
--- a/torch/autograd/profiler.py
+++ b/torch/autograd/profiler.py
@@ -6,6 +6,8 @@
from collections import defaultdict, namedtuple
from operator import attrgetter
+from typing import List, Dict, Tuple, Optional
+
try:
# Available in Python >= 3.2
from contextlib import ContextDecorator
@@ -13,6 +15,13 @@
import functools
class ContextDecorator(object): # type: ignore[no-redef]
+
+ def __enter__(self):
+ raise NotImplementedError
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ raise NotImplementedError
+
def __call__(self, func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
@@ -78,13 +87,13 @@
# Algorithm has O(N * log(N)) complexity where N is number of
# intervals
for thread_id, thread_events in threads:
- thread_events = sorted(
+ thread_events_ = sorted(
thread_events,
key=lambda event: [event.cpu_interval.start, -event.cpu_interval.end],
)
- current_events = []
+ current_events: List[FunctionEvent] = []
cur_end = 0
- for event in thread_events:
+ for event in thread_events_:
while len(current_events) > 0:
parent = current_events[-1]
if event.cpu_interval.start >= parent.cpu_interval.end or \
@@ -253,7 +262,7 @@
An EventList containing FunctionEventAvg objects.
"""
self.populate_cpu_children()
- stats = defaultdict(FunctionEventAvg)
+ stats: Dict[Tuple[int, Tuple[int, int]], FunctionEventAvg] = defaultdict(FunctionEventAvg)
def get_key(event, group_by_input_shapes, group_by_stack_n):
key = [str(event.key), str(event.node_id)]
@@ -413,6 +422,7 @@
def table(self, sort_by=None, row_limit=100, header=None, top_level_events_only=False):
self._check_finish()
+ assert self.function_events is not None
return self.function_events.table(
sort_by=sort_by, row_limit=row_limit, header=header,
top_level_events_only=top_level_events_only
@@ -421,16 +431,19 @@
def export_chrome_trace(self, path):
self._check_finish()
+ assert self.function_events is not None
return self.function_events.export_chrome_trace(path)
export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__
def key_averages(self, group_by_input_shape=False, group_by_stack_n=0):
self._check_finish()
+ assert self.function_events is not None
return self.function_events.key_averages(group_by_input_shape, group_by_stack_n)
key_averages.__doc__ = EventList.key_averages.__doc__
def total_average(self):
self._check_finish()
+ assert self.function_events is not None
return self.function_events.total_average()
total_average.__doc__ = EventList.total_average.__doc__
@@ -440,6 +453,7 @@
all self times across all the events.
"""
self._check_finish()
+ assert self.function_events is not None
return self.function_events.self_cpu_time_total
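The repeated `assert self.function_events is not None` lines are mypy's idiomatic Optional narrowing: the attribute starts as `None` and is only populated once profiling finishes. A toy illustration of the pattern (not the profiler itself):

```python
from typing import List, Optional

class Prof:  # toy stand-in, not the real profiler class
    def __init__(self) -> None:
        self.function_events: Optional[List[int]] = None

    def total(self) -> int:
        # after the assert, mypy narrows Optional[List[int]] to List[int],
        # so the call below type-checks
        assert self.function_events is not None
        return sum(self.function_events)
```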
@@ -694,11 +708,11 @@
@property
def cpu_time(self):
- return 0.0 if self.count == 0 else 1.0 * self.cpu_time_total / self.count
+ return 0.0 if self.count == 0 else 1.0 * self.cpu_time_total / self.count # type: ignore
@property
def cuda_time(self):
- return 0.0 if self.count == 0 else 1.0 * self.cuda_time_total / self.count
+ return 0.0 if self.count == 0 else 1.0 * self.cuda_time_total / self.count # type: ignore
class Interval(object):
@@ -719,24 +733,24 @@
self, id, node_id, name, thread, cpu_start, cpu_end, fwd_thread=None, input_shapes=None,
stack=None, scope=0, cpu_memory_usage=0, cuda_memory_usage=0, is_async=False,
is_remote=True, sequence_nr=-1):
- self.id = id
- self.node_id = node_id
- self.name = name
- self.cpu_interval = Interval(cpu_start, cpu_end)
- self.thread = thread
- self.fwd_thread = fwd_thread
- self.kernels = []
- self.count = 1
- self.cpu_children = []
- self.cpu_parent = None
- self.input_shapes = input_shapes
- self.stack = stack
- self.scope = scope
- self.cpu_memory_usage = cpu_memory_usage
- self.cuda_memory_usage = cuda_memory_usage
- self.is_async = is_async
- self.is_remote = is_remote
- self.sequence_nr = sequence_nr
+ self.id: int = id
+ self.node_id: int = node_id
+ self.name: str = name
+ self.cpu_interval: Interval = Interval(cpu_start, cpu_end)
+ self.thread: int = thread
+ self.fwd_thread: Optional[int] = fwd_thread
+ self.kernels: List[Kernel] = []
+ self.count: int = 1
+ self.cpu_children: List[FunctionEvent] = []
+ self.cpu_parent: Optional[FunctionEvent] = None
+ self.input_shapes: Tuple[int, ...] = input_shapes
+ self.stack: List = stack
+ self.scope: int = scope
+ self.cpu_memory_usage: int = cpu_memory_usage
+ self.cuda_memory_usage: int = cuda_memory_usage
+ self.is_async: bool = is_async
+ self.is_remote: bool = is_remote
+ self.sequence_nr: int = sequence_nr
def append_kernel(self, name, device, start, end):
self.kernels.append(Kernel(name, device, Interval(start, end)))
@@ -830,24 +844,24 @@
class FunctionEventAvg(FormattedTimesMixin):
"""Used to average stats over multiple FunctionEvent objects."""
def __init__(self):
- self.key = None
- self.count = 0
- self.node_id = 0
- self.is_async = False
- self.is_remote = False
- self.cpu_time_total = 0
- self.cuda_time_total = 0
- self.self_cpu_time_total = 0
- self.self_cuda_time_total = 0
- self.input_shapes = None
- self.stack = None
- self.scope = None
- self.cpu_memory_usage = 0
- self.cuda_memory_usage = 0
- self.self_cpu_memory_usage = 0
- self.self_cuda_memory_usage = 0
- self.cpu_children = None
- self.cpu_parent = None
+ self.key: Optional[str] = None
+ self.count: int = 0
+ self.node_id: int = 0
+ self.is_async: bool = False
+ self.is_remote: bool = False
+ self.cpu_time_total: int = 0
+ self.cuda_time_total: int = 0
+ self.self_cpu_time_total: int = 0
+ self.self_cuda_time_total: int = 0
+ self.input_shapes: Optional[List[List[int]]] = None
+ self.stack: Optional[List] = None
+ self.scope: Optional[int] = None
+ self.cpu_memory_usage: int = 0
+ self.cuda_memory_usage: int = 0
+ self.self_cpu_memory_usage: int = 0
+ self.self_cuda_memory_usage: int = 0
+ self.cpu_children: Optional[List[FunctionEvent]] = None
+ self.cpu_parent: Optional[FunctionEvent] = None
def add(self, other):
if self.key is None:
@@ -950,6 +964,7 @@
# and the CPU time of the cuda start event for the device
def adjusted_time(cuda_record, cuda_records_map):
assert cuda_record.device() != -1
+ assert start_record is not None
cuda_time_0 = cuda_records_map[(cuda_record.node_id(), cuda_record.device())]
return cuda_time_0.cuda_elapsed_us(cuda_record) + start_record.cpu_elapsed_us(cuda_time_0)
@@ -1102,6 +1117,8 @@
for row in conn.execute(marker_query):
unique.see(row['marker_id'])
evt = FunctionEvent(id=row['marker_id'],
+                                node_id=0,  # the nvprof trace carries no node_id;
+                                            # use 0 so FunctionEvent() can be constructed
name=strings[row['name']],
cpu_start=row['start_time'],
cpu_end=row['end_time'],
@@ -1215,15 +1232,15 @@
# Have to use a list because nonlocal is Py3 only...
SPACING_SIZE = 2
- row_format = [""]
- header_sep = [""]
- line_length = [-SPACING_SIZE]
+ row_format_lst = [""]
+ header_sep_lst = [""]
+ line_length_lst = [-SPACING_SIZE]
MAX_STACK_ENTRY = 5
def add_column(padding, text_dir='>'):
- row_format[0] += '{: ' + text_dir + str(padding) + '}' + (' ' * SPACING_SIZE)
- header_sep[0] += '-' * padding + (' ' * SPACING_SIZE)
- line_length[0] += padding + SPACING_SIZE
+ row_format_lst[0] += '{: ' + text_dir + str(padding) + '}' + (' ' * SPACING_SIZE)
+ header_sep_lst[0] += '-' * padding + (' ' * SPACING_SIZE)
+ line_length_lst[0] += padding + SPACING_SIZE
add_column(name_column_width)
for _ in headers[1:]:
@@ -1237,10 +1254,10 @@
headers.append('Source Location')
add_column(src_column_width, text_dir='<')
- row_format = row_format[0]
- header_sep = header_sep[0]
- line_length = line_length[0]
- add_column = None
+ row_format = row_format_lst[0]
+ header_sep = header_sep_lst[0]
+ line_length = line_length_lst[0]
+ add_column = None # type: ignore
# Have to use a list because nonlocal is Py3 only...
result = []
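The `_lst` renames exist because mypy assigns one static type per variable name, so re-binding the nonlocal-workaround list to the `str` it accumulates (as the old code did) is rejected. A condensed sketch of the fix:

```python
from typing import List

# One name, one type: the accumulator keeps its List[str] name, and the
# final string gets a fresh name instead of re-binding `row_format`.
row_format_lst: List[str] = [""]
row_format_lst[0] += "{:>10}"
row_format = row_format_lst[0]  # a plain str from here on
print(row_format.format("total"))
```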
diff --git a/torch/autograd/variable.py b/torch/autograd/variable.py
index 1008d74..307f82d 100644
--- a/torch/autograd/variable.py
+++ b/torch/autograd/variable.py
@@ -7,9 +7,10 @@
return isinstance(other, torch.Tensor)
-class Variable(with_metaclass(VariableMeta, torch._C._LegacyVariableBase)):
+# mypy doesn't understand torch._six.with_metaclass
+class Variable(with_metaclass(VariableMeta, torch._C._LegacyVariableBase)): # type: ignore
pass
from torch._C import _ImperativeEngine as ImperativeEngine
-Variable._execution_engine = ImperativeEngine()
+Variable._execution_engine = ImperativeEngine() # type: ignore
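Both ignores here trace back to `with_metaclass`, whose six-style implementation returns a dynamically built throwaway class that mypy cannot accept as a base. A sketch of the helper as `torch._six` defined it (modeled on `six.with_metaclass`):

```python
def with_metaclass(meta, *bases):
    # The returned "temporary_class" exists only so that subclassing it
    # triggers metaclass.__new__, which rebuilds the real class with the
    # requested metaclass and bases. The dynamic return type is what
    # forces the `# type: ignore` on the class statements above.
    class metaclass(meta):
        def __new__(cls, name, this_bases, d):
            return meta(name, bases, d)
    return type.__new__(metaclass, "temporary_class", (), {})

class Meta(type):
    pass

class Base(object):
    pass

class C(with_metaclass(Meta, Base)):
    pass

assert type(C) is Meta and issubclass(C, Base)
```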