Add type hints for a few random functions/classes
Adds type hints for a few functions/classes that we use in [TorchGeo](https://github.com/microsoft/torchgeo).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74171
Approved by: https://github.com/jbschlosser
diff --git a/torch/autograd/grad_mode.py b/torch/autograd/grad_mode.py
index 331327e..552afa4 100644
--- a/torch/autograd/grad_mode.py
+++ b/torch/autograd/grad_mode.py
@@ -123,12 +123,12 @@
>>> z.requires_grad
False
"""
- def __init__(self):
+ def __init__(self) -> None:
if not torch._jit_internal.is_scripting():
super().__init__()
self.prev = False
- def __enter__(self):
+ def __enter__(self) -> None:
self.prev = torch.is_grad_enabled()
torch.set_grad_enabled(False)
diff --git a/torch/distributed/pipeline/sync/batchnorm.py b/torch/distributed/pipeline/sync/batchnorm.py
index a0e144b..c58fe00 100644
--- a/torch/distributed/pipeline/sync/batchnorm.py
+++ b/torch/distributed/pipeline/sync/batchnorm.py
@@ -5,7 +5,7 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Tracks the running statistics per mini-batch instead of micro-batch."""
-from typing import Optional, TypeVar, cast
+from typing import TypeVar, cast
import torch
from torch import Tensor, nn
@@ -35,7 +35,7 @@
self,
num_features: int,
eps: float = 1e-5,
- momentum: Optional[float] = 0.1,
+ momentum: float = 0.1,
affine: bool = True,
chunks: int = 1,
) -> None:
diff --git a/torch/functional.py b/torch/functional.py
index efb98c2..29a66f7 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -1,5 +1,5 @@
from typing import (
- Tuple, Optional, Union, Any, Sequence, TYPE_CHECKING
+ List, Tuple, Optional, Union, Any, Sequence, TYPE_CHECKING
)
import torch
@@ -9,7 +9,7 @@
from .overrides import (
has_torch_function, has_torch_function_unary, has_torch_function_variadic,
handle_torch_function)
-from ._jit_internal import boolean_dispatch, List
+from ._jit_internal import boolean_dispatch
from ._jit_internal import _overload as overload
Tensor = torch.Tensor
@@ -137,7 +137,9 @@
-def split(tensor, split_size_or_sections, dim=0):
+def split(
+ tensor: Tensor, split_size_or_sections: Union[int, List[int]], dim: int = 0
+) -> List[Tensor]:
r"""Splits the tensor into chunks. Each chunk is a view of the original tensor.
If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will
@@ -187,7 +189,7 @@
return tensor.split(split_size_or_sections, dim)
-def einsum(*args):
+def einsum(*args: Any) -> Tensor:
r"""einsum(equation, *operands) -> Tensor
Sums the product of the elements of the input :attr:`operands` along dimensions specified using a notation
diff --git a/torch/hub.py b/torch/hub.py
index 6df86c6..995e678 100644
--- a/torch/hub.py
+++ b/torch/hub.py
@@ -10,10 +10,13 @@
import warnings
import zipfile
from pathlib import Path
+from typing import Callable, Dict, Optional, Union
from urllib.error import HTTPError
from urllib.request import urlopen, Request
from urllib.parse import urlparse # noqa: F401
+import torch.nn as nn
+
try:
from tqdm.auto import tqdm # automatically select proper tqdm submodule if available
except ImportError:
@@ -645,7 +648,14 @@
return torch.load(extracted_file, map_location=map_location)
-def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None):
+def load_state_dict_from_url(
+ url: str,
+ model_dir: Optional[str] = None,
+ map_location: Optional[Union[Callable[[str], str], Dict[str, str]]] = None,
+ progress: bool = True,
+ check_hash: bool = False,
+ file_name: Optional[str] = None
+) -> nn.Module:
r"""Loads the Torch serialized object at the given URL.
If downloaded file is a zip file, it will be automatically
diff --git a/torch/nn/init.py b/torch/nn/init.py
index aad2380..09d233e 100644
--- a/torch/nn/init.py
+++ b/torch/nn/init.py
@@ -363,7 +363,9 @@
return fan_in if mode == 'fan_in' else fan_out
-def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
+def kaiming_uniform_(
+ tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'
+):
r"""Fills the input `Tensor` with values according to the method
described in `Delving deep into rectifiers: Surpassing human-level
performance on ImageNet classification` - He, K. et al. (2015), using a
@@ -410,7 +412,9 @@
return tensor.uniform_(-bound, bound)
-def kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
+def kaiming_normal_(
+ tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'
+):
r"""Fills the input `Tensor` with values according to the method
described in `Delving deep into rectifiers: Surpassing human-level
performance on ImageNet classification` - He, K. et al. (2015), using a
diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py
index 1551aba..65271eb 100644
--- a/torch/nn/modules/batchnorm.py
+++ b/torch/nn/modules/batchnorm.py
@@ -118,14 +118,14 @@
class _BatchNorm(_NormBase):
def __init__(
self,
- num_features,
- eps=1e-5,
- momentum=0.1,
- affine=True,
- track_running_stats=True,
+ num_features: int,
+ eps: float = 1e-5,
+ momentum: float = 0.1,
+ affine: bool = True,
+ track_running_stats: bool = True,
device=None,
dtype=None
- ):
+ ) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super(_BatchNorm, self).__init__(
num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
diff --git a/torch/nn/modules/linear.py b/torch/nn/modules/linear.py
index 539ab8a..48aadd2 100644
--- a/torch/nn/modules/linear.py
+++ b/torch/nn/modules/linear.py
@@ -1,4 +1,5 @@
import math
+from typing import Any
import torch
from torch import Tensor
@@ -29,7 +30,7 @@
torch.Size([128, 20])
"""
- def __init__(self, *args, **kwargs):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
super(Identity, self).__init__()
def forward(self, input: Tensor) -> Tensor: