import torch
from torch.autograd import Variable

from .module import Module

# TODO: check contiguous in THNN

class BatchNorm(Module):
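    """Applies Batch Normalization over a 2D input (a mini-batch of 1D features).

    The mean and variance are computed per feature over the mini-batch;
    running estimates are kept (updated with `momentum`) for use at
    evaluation time. When `affine` is True, learnable per-feature weight
    and bias parameters are applied after normalization.
    """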
    expected_dim = 2

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(BatchNorm, self).__init__()
        self.affine = affine
        self.eps = eps
        self.momentum = momentum
        self.running_mean = torch.zeros(num_features)
        self.running_var = torch.ones(num_features)
        if self.affine:
            self.weight = Variable(torch.Tensor(num_features))
            self.bias = Variable(torch.Tensor(num_features))
            self.reset_parameters()
        else:
            self.weight = None
            self.bias = None

    def reset_parameters(self):
        # Reset the learnable affine parameters (if present) and the running statistics.
        if self.weight is not None:
            self.weight.data.uniform_()
        if self.bias is not None:
            self.bias.data.zero_()
        self.running_mean.zero_()
        self.running_var.fill_(1)

    def _checkInputDim(self, input):
        if input.dim() != self.expected_dim:
            raise RuntimeError('only mini-batch supported ({}D tensor), got {}D tensor instead'
                               .format(self.expected_dim, input.dim()))
        if input.size(1) != self.running_mean.nElement():
            raise RuntimeError('got {}-feature tensor, expected {}'
                               .format(input.size(1), self.running_mean.nElement()))

    def forward(self, input):
        self._checkInputDim(input)
        args = (input,)
        if self.weight is not None:
            args = args + (self.weight, self.bias)
        return self._backend.BatchNorm(self.running_mean,
                self.running_var, self.train, self.momentum, self.eps)(*args)


class BatchNorm2d(BatchNorm):
    expected_dim = 4
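

# A minimal usage sketch (illustrative only): it assumes this module sits inside a
# working torch.nn package, so that `self._backend.BatchNorm` resolves to the THNN
# backend and `Module.__call__` dispatches to `forward`.
#
#   import torch
#   from torch.autograd import Variable
#
#   bn = BatchNorm2d(num_features=16)             # expects 4D input: N x C x H x W
#   input = Variable(torch.randn(8, 16, 32, 32))  # mini-batch of 8
#   output = bn(input)                            # normalized, same shape as input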