from typing import Any, TypeVar

from torch.nn.modules import Module

T_co = TypeVar('T_co', covariant=True)

class DistributedDataParallelCPU(Module[T_co]):
    # The wrapped user module; forward() delegates to it.
    module: Module[T_co] = ...
    # True while gradients still need an allreduce before the optimizer step.
    needs_reduction: bool = ...
    def __init__(self, module: Module[T_co]) -> None: ...
    # Broadcasts parameters from rank 0 so every process starts identical.
    def sync_parameters(self) -> None: ...
    # TODO: double-check that this returns a T_co and not a list of T_cos
    def forward(self, *inputs: Any, **kwargs: Any) -> T_co: ...
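
# A minimal usage sketch (illustration only, not part of the stub). It assumes
# a CPU-capable process group backend such as 'gloo', initialized with the
# usual env:// variables; the model and data below are hypothetical stand-ins.
# In the PyTorch versions that ship this class, wrapping looks roughly like:
#
#   import torch
#   import torch.distributed as dist
#   import torch.nn as nn
#   from torch.nn.parallel import DistributedDataParallelCPU
#
#   dist.init_process_group(backend='gloo')              # CPU collective backend
#   model = DistributedDataParallelCPU(nn.Linear(10, 1)) # broadcasts params from rank 0
#   out = model(torch.randn(4, 10))                      # forward() delegates to the wrapped module
#   out.sum().backward()                                 # gradients are allreduced across processes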