blob: 54f8d2b1256ee004702cac821fdca705f8ccd89f [file] [log] [blame]
import torch
from .Module import Module
from .utils import clear
class Dropout(Module):
    """Inverted dropout layer.

    While ``self.train`` is true and ``p > 0``, each element of the input is
    zeroed with probability ``p``. The surviving elements are scaled by
    ``1 / (1 - p)``, so no rescaling is needed when ``self.train`` is false.
    In that mode the output is just the input (copied, or aliased when
    ``inplace`` is set).

    Args:
        p: probability of dropping an element.
        inplace: if true, ``output``/``gradInput`` share storage with
            ``input``/``gradOutput`` (via ``set_``), so masking modifies those
            tensors directly instead of a copy.
    """

    def __init__(self, p=0.5, inplace=False):
        super(Dropout, self).__init__()
        self.p = p
        self.inplace = inplace
        # Read by updateOutput/updateGradInput to decide whether to mask.
        # NOTE(review): presumably toggled by Module's training/evaluate
        # helpers — confirm in Module.
        self.train = True
        # Mask from the most recent training-mode forward pass; reused by
        # updateGradInput.
        self.noise = torch.Tensor()

    def updateOutput(self, input):
        """Compute the forward pass into ``self.output`` and return it."""
        if self.inplace:
            # Alias the input's storage: the mask below is applied to
            # ``input`` itself.
            self.output.set_(input)
        else:
            self.output.resizeAs_(input).copy_(input)

        if self.p > 0 and self.train:
            self.noise.resizeAs_(input)
            if self.p == 1:
                # Everything is dropped. Sampling with keep-probability 0 and
                # then dividing by (1 - p) == 0 would turn the mask into NaN
                # (0 / 0), so use an all-zero mask directly.
                self.noise.fill_(0)
            else:
                keep = 1 - self.p
                self.noise.bernoulli_(keep)
                # Scale kept units now so evaluation needs no rescaling.
                self.noise.div_(keep)
            self.output.mul_(self.noise)

        return self.output

    def updateGradInput(self, input, gradOutput):
        """Back-propagate ``gradOutput`` through the mask into ``self.gradInput``.

        Reuses ``self.noise`` from the last ``updateOutput`` call. The caller
        is assumed to invoke this after a forward pass on the same input with
        the same ``p``/``train`` settings (assumption — not checked here).
        """
        if self.inplace:
            self.gradInput.set_(gradOutput)
        else:
            self.gradInput.resizeAs_(gradOutput).copy_(gradOutput)

        if self.p > 0 and self.train:
            # Simply mask the gradients with the noise vector.
            self.gradInput.mul_(self.noise)

        return self.gradInput

    def setp(self, p):
        """Set the drop probability used by subsequent forward passes."""
        self.p = p

    def __repr__(self):
        return super(Dropout, self).__repr__() + '({:.4f})'.format(self.p)

    def clearState(self):
        """Reset the cached mask, then defer to ``Module.clearState``."""
        # NOTE(review): ``clear`` comes from .utils; presumably it empties or
        # resets the named tensor attribute — confirm there.
        clear(self, 'noise')
        return super(Dropout, self).clearState()