Add type hints for assorted functions and classes

Adds type hints for several functions and classes that we use in [TorchGeo](https://github.com/microsoft/torchgeo): `torch.no_grad`, `torch.split`, `torch.einsum`, `torch.hub.load_state_dict_from_url`, `torch.nn.init.kaiming_uniform_`/`kaiming_normal_`, `torch.nn.modules.batchnorm._BatchNorm`, `nn.Identity`, and the pipeline-parallel `DeferredBatchNorm`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74171
Approved by: https://github.com/jbschlosser
diff --git a/torch/autograd/grad_mode.py b/torch/autograd/grad_mode.py
index 331327e..552afa4 100644
--- a/torch/autograd/grad_mode.py
+++ b/torch/autograd/grad_mode.py
@@ -123,12 +123,12 @@
         >>> z.requires_grad
         False
     """
-    def __init__(self):
+    def __init__(self) -> None:
         if not torch._jit_internal.is_scripting():
             super().__init__()
         self.prev = False
 
-    def __enter__(self):
+    def __enter__(self) -> None:
         self.prev = torch.is_grad_enabled()
         torch.set_grad_enabled(False)
 
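
With `__init__` and `__enter__` annotated as returning `None`, static checkers can verify constructor calls and will flag code that tries to bind a value from the `with` target. A minimal sketch of the behavior the hints describe:

```python
import torch

x = torch.ones(3, requires_grad=True)

# Gradient recording is disabled inside the block; since
# __enter__ -> None, `with torch.no_grad() as g:` would bind g to None.
with torch.no_grad():
    y = x * 2

assert y.requires_grad is False
```
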
diff --git a/torch/distributed/pipeline/sync/batchnorm.py b/torch/distributed/pipeline/sync/batchnorm.py
index a0e144b..c58fe00 100644
--- a/torch/distributed/pipeline/sync/batchnorm.py
+++ b/torch/distributed/pipeline/sync/batchnorm.py
@@ -5,7 +5,7 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 """Tracks the running statistics per mini-batch instead of micro-batch."""
-from typing import Optional, TypeVar, cast
+from typing import TypeVar, cast
 
 import torch
 from torch import Tensor, nn
@@ -35,7 +35,7 @@
         self,
         num_features: int,
         eps: float = 1e-5,
-        momentum: Optional[float] = 0.1,
+        momentum: float = 0.1,
         affine: bool = True,
         chunks: int = 1,
     ) -> None:
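
Narrowing `momentum` from `Optional[float]` to `float` matches how this module consumes it; unlike `nn.BatchNorm*`, there is presumably no cumulative-average fallback for `momentum=None` here. A construction sketch, assuming the experimental pipeline package is available in your build:

```python
from torch.distributed.pipeline.sync.batchnorm import DeferredBatchNorm

# momentum must now be a concrete float; None is no longer type-correct.
bn = DeferredBatchNorm(num_features=32, eps=1e-5, momentum=0.1, affine=True, chunks=2)
```
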
diff --git a/torch/functional.py b/torch/functional.py
index efb98c2..29a66f7 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -1,5 +1,5 @@
 from typing import (
-    Tuple, Optional, Union, Any, Sequence, TYPE_CHECKING
+    List, Tuple, Optional, Union, Any, Sequence, TYPE_CHECKING
 )
 
 import torch
@@ -9,7 +9,7 @@
 from .overrides import (
     has_torch_function, has_torch_function_unary, has_torch_function_variadic,
     handle_torch_function)
-from ._jit_internal import boolean_dispatch, List
+from ._jit_internal import boolean_dispatch
 from ._jit_internal import _overload as overload
 
 Tensor = torch.Tensor
@@ -137,7 +137,9 @@
 
 
 
-def split(tensor, split_size_or_sections, dim=0):
+def split(
+    tensor: Tensor, split_size_or_sections: Union[int, List[int]], dim: int = 0
+) -> List[Tensor]:
     r"""Splits the tensor into chunks. Each chunk is a view of the original tensor.
 
     If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will
@@ -187,7 +189,7 @@
     return tensor.split(split_size_or_sections, dim)
 
 
-def einsum(*args):
+def einsum(*args: Any) -> Tensor:
     r"""einsum(equation, *operands) -> Tensor
 
     Sums the product of the elements of the input :attr:`operands` along dimensions specified using a notation
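
The new `split` signature spells out both accepted forms of `split_size_or_sections`, and `einsum` now advertises its `Tensor` return. A short sketch of the call sites these hints cover:

```python
import torch

x = torch.arange(10)
parts_a = torch.split(x, 3)          # int: equal chunks of 3 (last chunk may be smaller)
parts_b = torch.split(x, [2, 3, 5])  # List[int]: explicit section lengths summing to 10

a = torch.randn(2, 3)
b = torch.randn(3, 4)
c = torch.einsum("ij,jk->ik", a, b)  # now typed as returning a Tensor
```
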
diff --git a/torch/hub.py b/torch/hub.py
index 6df86c6..995e678 100644
--- a/torch/hub.py
+++ b/torch/hub.py
@@ -10,10 +10,11 @@
 import warnings
 import zipfile
 from pathlib import Path
+from typing import Any, Callable, Dict, Optional, Union
 from urllib.error import HTTPError
 from urllib.request import urlopen, Request
 from urllib.parse import urlparse  # noqa: F401
 
 try:
     from tqdm.auto import tqdm  # automatically select proper tqdm submodule if available
 except ImportError:
@@ -645,7 +646,14 @@
     return torch.load(extracted_file, map_location=map_location)
 
 
-def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None):
+def load_state_dict_from_url(
+    url: str,
+    model_dir: Optional[str] = None,
+    map_location: Optional[Union[Callable[[str], str], Dict[str, str]]] = None,
+    progress: bool = True,
+    check_hash: bool = False,
+    file_name: Optional[str] = None
+) -> Dict[str, Any]:
     r"""Loads the Torch serialized object at the given URL.
 
     If downloaded file is a zip file, it will be automatically
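
Since the function returns the deserialized checkpoint (a state dict) rather than a module, call sites can annotate accordingly. A sketch with a hypothetical URL:

```python
from typing import Any, Dict

import torch.hub

# The URL below is hypothetical, for illustration only.
state_dict: Dict[str, Any] = torch.hub.load_state_dict_from_url(
    "https://example.com/checkpoints/model.pth",
    map_location={"cuda:0": "cpu"},  # Dict[str, str]: remap storages to CPU
    progress=False,
)
```
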
diff --git a/torch/nn/init.py b/torch/nn/init.py
index aad2380..09d233e 100644
--- a/torch/nn/init.py
+++ b/torch/nn/init.py
@@ -363,7 +363,9 @@
     return fan_in if mode == 'fan_in' else fan_out
 
 
-def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
+def kaiming_uniform_(
+    tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'
+):
     r"""Fills the input `Tensor` with values according to the method
     described in `Delving deep into rectifiers: Surpassing human-level
     performance on ImageNet classification` - He, K. et al. (2015), using a
@@ -410,7 +412,9 @@
         return tensor.uniform_(-bound, bound)
 
 
-def kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
+def kaiming_normal_(
+    tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'
+):
     r"""Fills the input `Tensor` with values according to the method
     described in `Delving deep into rectifiers: Surpassing human-level
     performance on ImageNet classification` - He, K. et al. (2015), using a
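
Both initializers mutate `tensor` in place; the hints pin down the scalar and string parameters. A short usage sketch:

```python
import torch
import torch.nn as nn

w = torch.empty(128, 64)
nn.init.kaiming_uniform_(w, a=0.1, mode='fan_in', nonlinearity='leaky_relu')
nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
```
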
diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py
index 1551aba..65271eb 100644
--- a/torch/nn/modules/batchnorm.py
+++ b/torch/nn/modules/batchnorm.py
@@ -118,14 +118,14 @@
 class _BatchNorm(_NormBase):
     def __init__(
         self,
-        num_features,
-        eps=1e-5,
-        momentum=0.1,
-        affine=True,
-        track_running_stats=True,
+        num_features: int,
+        eps: float = 1e-5,
+        momentum: float = 0.1,
+        affine: bool = True,
+        track_running_stats: bool = True,
         device=None,
         dtype=None
-    ):
+    ) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
         super(_BatchNorm, self).__init__(
             num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
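
The typed constructor is inherited by the concrete `BatchNorm1d`/`2d`/`3d` subclasses, so their call sites check cleanly as well. For example:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(num_features=16, eps=1e-5, momentum=0.1,
                    affine=True, track_running_stats=True)
out = bn(torch.randn(4, 16, 8, 8))
```
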
diff --git a/torch/nn/modules/linear.py b/torch/nn/modules/linear.py
index 539ab8a..48aadd2 100644
--- a/torch/nn/modules/linear.py
+++ b/torch/nn/modules/linear.py
@@ -1,4 +1,5 @@
 import math
+from typing import Any
 
 import torch
 from torch import Tensor
@@ -29,7 +30,7 @@
         torch.Size([128, 20])
 
     """
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super(Identity, self).__init__()
 
     def forward(self, input: Tensor) -> Tensor:
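
`Identity` accepts and discards arbitrary constructor arguments, which `*args: Any, **kwargs: Any` now documents explicitly. Mirroring the docstring example above:

```python
import torch
import torch.nn as nn

m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
output = m(torch.randn(128, 20))
assert output.size() == torch.Size([128, 20])
```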