| from typing import Any, Callable, Union, Tuple, Sequence, Optional |
| from .. import Tensor |
| from .grad_mode import no_grad as no_grad, enable_grad as enable_grad, \ |
| set_grad_enabled as set_grad_enabled |
| |
# TODO: make Variable and Function more precise.
class Variable:
    # Placeholder: this stub does not declare any attributes or methods for Variable yet.
    ...
| |
class Function:
    # Both hooks are declared as static methods; each receives `ctx` as its first
    # argument, and the precise argument/return types are left as `Any` in this stub.
    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
        ...

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        ...
| |
class NestedIOFunction(Function):
    # Subclass of Function whose forward/backward are instance methods rather than
    # static methods.
    # The 'type: ignore' comments are needed because forward and backward are declared as
    # '@staticmethod' in the superclass (Function) but are instance methods here, which mypy
    # reports as an incompatible override. They must stay on the `def` lines mypy flags.
    def backward(self, *gradients: Any) -> Any: ...  # type: ignore
    def forward(self, *args: Any) -> Tuple[Any, ...]: ...  # type: ignore
    def save_for_backward(self, *args: Any) -> None: ...
    def mark_dirty(self, *args: Any, **kwargs: Any) -> None: ...
    def mark_non_differentiable(self, *args: Any, **kwargs: Any) -> None: ...
    def forward_extended(self, *input: Any) -> None: ...
    def backward_extended(self, *grad_output: Any) -> None: ...
| |
# 'func' accepts a variable number of tensor arguments, which isn't expressible in the type system at the moment.
# If https://mypy.readthedocs.io/en/latest/additional_features.html?highlight=callable#extended-callable-types is accepted,
# the '...' first argument of Callable can be replaced with VarArg(Tensor).
# For now, we permit any input.
# Default values are elided with `...`, as usual in a stub.
def gradcheck(
    func: Callable[..., Union[Tensor, Tuple[Tensor, ...]]],
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    eps: float = ...,
    atol: float = ...,
    rtol: float = ...,
    raise_exception: bool = ...,
    check_sparse_nnz: bool = ...,
) -> bool: ...
# Same calling convention for `func`/`inputs` as gradcheck; defaults are elided with `...`.
def gradgradcheck(
    func: Callable[..., Union[Tensor, Tuple[Tensor, ...]]],
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    eps: float = ...,
    atol: float = ...,
    rtol: float = ...,
    gen_non_contig_grad_outputs: bool = ...,
    raise_exception: bool = ...,
) -> bool: ...
| |
class detect_anomaly:
    # Context manager; this stub declares no constructor arguments.
    # `__exit__` is annotated `-> None` rather than `-> bool`: a `bool` return tells type
    # checkers the manager may swallow exceptions, which changes what they consider
    # reachable after a `with` block (e.g. spurious "missing return statement" errors).
    def __enter__(self) -> None: ...
    def __exit__(self, *args: Any) -> None: ...
| |
class set_detect_anomaly:
    # Context manager constructed with a boolean `mode`.
    # `__exit__` is annotated `-> None` rather than `-> bool`: a `bool` return tells type
    # checkers the manager may swallow exceptions, which changes what they consider
    # reachable after a `with` block (e.g. spurious "missing return statement" errors).
    def __init__(self, mode: bool) -> None: ...
    def __enter__(self) -> None: ...
    def __exit__(self, *args: Any) -> None: ...
| |
# Either a single tensor or a sequence of tensors; used by backward() and grad() below.
_TensorOrTensors = Union[Tensor, Sequence[Tensor]]

def backward(
    tensors: _TensorOrTensors,
    grad_tensors: Optional[_TensorOrTensors] = ...,
    retain_graph: Optional[bool] = ...,
    create_graph: bool = ...,
) -> None: ...

def grad(
    outputs: _TensorOrTensors,
    inputs: _TensorOrTensors,
    grad_outputs: Optional[_TensorOrTensors] = ...,
    retain_graph: Optional[bool] = ...,
    create_graph: bool = ...,
    only_inputs: bool = ...,
    allow_unused: bool = ...,
) -> Tuple[Tensor, ...]: ...