blob: 80ad8cbf278c38a3eb867b3d4afe61816cdab1fd [file] [log] [blame]
from typing import Any, Optional, TypeVar
from .common_types import _devices_t, _device_t
from ..modules import Module
from ... import device, Tensor
# Covariant type parameter. DataParallel is generic in it, and it is the
# declared return type of DataParallel.forward / DataParallel.__call__.
T_co = TypeVar(
    'T_co',
    covariant=True,
)
class DataParallel(Module[T_co]):
    """Type stub for the ``DataParallel`` module wrapper.

    Generic in ``T_co``: the type parameter of the wrapped ``Module``, which
    is also the declared return type of ``forward`` and ``__call__``.
    Only signatures are declared here; all bodies are elided (``...``).
    """

    # The wrapped module. Parameterized with T_co so the attribute keeps the
    # same type that __init__ accepted (``module: Module[T_co]``) instead of
    # degrading to an unparameterized ``Module``.
    module: Module[T_co] = ...
    device_ids: _devices_t = ...
    dim: int = ...
    output_device: _device_t = ...
    # Typed as a ``device`` object, unlike output_device (typed _device_t).
    src_device_obj: device = ...

    def __init__(
        self,
        module: Module[T_co],
        device_ids: Optional[_devices_t] = ...,
        output_device: Optional[_device_t] = ...,
        dim: int = ...,
    ) -> None: ...

    def forward(self, *inputs: Any, **kwargs: Any) -> T_co: ...

    # Redeclared with a T_co return type; presumably to override the
    # inherited Module.__call__ signature so calling the wrapper is typed
    # as returning T_co — confirm against ..modules.
    def __call__(self, *inputs: Any, **kwargs: Any) -> T_co: ...
def data_parallel(
    module: Module,
    inputs: Any,
    device_ids: Optional[_devices_t] = ...,
    output_device: Optional[_device_t] = ...,
    dim: int = ...,
    module_kwargs: Optional[Any] = ...,
) -> Tensor:
    """Stub signature for ``data_parallel``; the body is elided.

    Every parameter after ``inputs`` is declared with an elided default
    (``...``), so callers may omit them. Note that ``Optional[Any]`` on
    ``module_kwargs`` is equivalent to plain ``Any`` for a type checker.
    """
    ...