from torch.nn.modules import Module
from typing import Any, TypeVar

# Covariant type parameter for the wrapped module's output type. The wrapper
# below is generic over it, and its ``forward`` is declared to return it.
T_co = TypeVar('T_co', covariant=True)
class DistributedDataParallelCPU(Module[T_co]):
    """Type stub for a data-parallel wrapper around a single ``Module``.

    The wrapper is generic in ``T_co``, the output type of the wrapped
    ``module``. ``forward`` is declared to return that same type.
    """

    # The wrapped module; its output type parameterizes this wrapper.
    module: Module[T_co] = ...
    # Boolean flag exposed by the implementation; its exact meaning is not
    # visible from this stub.
    needs_reduction: bool = ...

    def __init__(self, module: Module[T_co]) -> None: ...

    def sync_parameters(self) -> None: ...

    # NOTE(review): declared as returning a single T_co, which assumes the
    # wrapper passes the wrapped module's output through unchanged, not as a
    # list of per-replica outputs. Confirm against the runtime implementation.
    def forward(self, *inputs: Any, **kwargs: Any) -> T_co: ...