# blob: becd74a29788d52b4e20c59a7eb27543f15c71e1 [file] [log] [blame]
from ... import Tensor
from .. import Parameter
from .module import Module
from typing import Any, Optional
class _BatchNorm(Module):
    # Shared typing stub for the batch-normalization modules declared in this
    # file; BatchNorm1d/2d/3d and SyncBatchNorm all derive from it.
    num_features: int = ...
    eps: float = ...
    # Optional: PyTorch documents that momentum may be None, which selects a
    # cumulative moving average for the running statistics.
    momentum: Optional[float] = ...
    affine: bool = ...
    track_running_stats: bool = ...
    # NOTE(review): when affine is False the runtime module presumably holds
    # None for these; consider Optional[Parameter] once callers are checked.
    weight: Parameter = ...
    bias: Parameter = ...

    def __init__(self, num_features: int, eps: float = ..., momentum: Optional[float] = ...,
                 affine: bool = ..., track_running_stats: bool = ...) -> None: ...

    def reset_running_stats(self) -> None: ...

    def reset_parameters(self) -> None: ...

    def forward(self, input: Tensor) -> Tensor: ...
class BatchNorm1d(_BatchNorm):
    # Declares nothing of its own; the full interface comes from _BatchNorm.
    ...
class BatchNorm2d(_BatchNorm):
    # Declares nothing of its own; the full interface comes from _BatchNorm.
    ...
class BatchNorm3d(_BatchNorm):
    # Declares nothing of its own; the full interface comes from _BatchNorm.
    ...
class SyncBatchNorm(_BatchNorm):
    # Typing stub that extends the _BatchNorm constructor with an extra
    # process_group argument.
    # TODO set process_group to the right type once torch.distributed is stubbed
    # momentum is Optional: PyTorch documents None as "use a cumulative
    # moving average" for the running statistics.
    def __init__(self, num_features: int, eps: float = ..., momentum: Optional[float] = ...,
                 affine: bool = ..., track_running_stats: bool = ...,
                 process_group: Optional[Any] = ...) -> None: ...

    def forward(self, input: Tensor) -> Tensor: ...