blob: 0fbbaf2ea59f3611c64b5c4038273824a1555004 [file] [log] [blame]
import warnings

# Bind ``torch`` explicitly: ``import torch.backends.cudnn as cudnn`` only
# binds ``cudnn``, and the guarded import below may fail.
import torch
from torch.autograd import Function, NestedIOFunction, Variable
import torch.backends.cudnn as cudnn

from .. import functional as F
from .thnn import rnnFusedPointwise as fusedBackend

try:
    # Registers ``cudnn.rnn``; guarded because it may be missing on some builds.
    import torch.backends.cudnn.rnn
except ImportError:
    pass
def RNNReLUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    """Advance an Elman RNN cell with a ReLU non-linearity by one step.

    Computes ``relu(linear(input, w_ih, b_ih) + linear(hidden, w_hh, b_hh))``.
    Either bias may be ``None``.

    Returns:
        The new hidden state.
    """
    input_proj = F.linear(input, w_ih, b_ih)
    hidden_proj = F.linear(hidden, w_hh, b_hh)
    return F.relu(input_proj + hidden_proj)
def RNNTanhCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    """Advance an Elman RNN cell with a tanh non-linearity by one step.

    Computes ``tanh(linear(input, w_ih, b_ih) + linear(hidden, w_hh, b_hh))``.
    Either bias may be ``None``.

    Returns:
        The new hidden state.
    """
    pre_activation = F.linear(input, w_ih, b_ih) + F.linear(hidden, w_hh, b_hh)
    return F.tanh(pre_activation)
def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    """Advance an LSTM cell by one time step.

    Args:
        input: input features for this step.
        hidden: pair ``(hx, cx)`` holding the previous hidden and cell states.
        w_ih, w_hh: input-to-hidden and hidden-to-hidden weights. Their
            projections are split into four equal chunks along dim 1, used
            in order as the input, forget, cell and output gates.
        b_ih, b_hh: optional biases.

    Returns:
        ``(hy, cy)``: the new hidden and cell states.

    CUDA inputs are dispatched to ``LSTMFused`` from the THNN
    ``rnnFusedPointwise`` backend. There the biases are passed only when
    ``b_ih`` is not None (``b_hh`` is assumed to accompany it).
    """
    if input.is_cuda:
        gates_from_input = F.linear(input, w_ih)
        gates_from_hidden = F.linear(hidden[0], w_hh)
        fused = fusedBackend.LSTMFused()
        if b_ih is None:
            return fused(gates_from_input, gates_from_hidden, hidden[1])
        return fused(gates_from_input, gates_from_hidden, hidden[1], b_ih, b_hh)

    h_prev, c_prev = hidden
    gates = F.linear(input, w_ih, b_ih) + F.linear(h_prev, w_hh, b_hh)
    in_raw, forget_raw, cell_raw, out_raw = gates.chunk(4, 1)

    in_gate = F.sigmoid(in_raw)
    forget_gate = F.sigmoid(forget_raw)
    candidate = F.tanh(cell_raw)
    out_gate = F.sigmoid(out_raw)

    c_next = (forget_gate * c_prev) + (in_gate * candidate)
    h_next = out_gate * F.tanh(c_next)
    return h_next, c_next
def GRUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    """Advance a GRU cell by one time step.

    Both projections are split into three equal chunks along dim 1, used in
    order as the reset, update ("input") and new-state components. The reset
    gate scales only the hidden-side new-state term.

    Returns:
        The new hidden state ``new + update * (hidden - new)``.

    CUDA inputs are dispatched to ``GRUFused`` from the THNN
    ``rnnFusedPointwise`` backend. There the biases are passed only when
    ``b_ih`` is not None (``b_hh`` is assumed to accompany it).
    """
    if input.is_cuda:
        from_input = F.linear(input, w_ih)
        from_hidden = F.linear(hidden, w_hh)
        fused = fusedBackend.GRUFused()
        if b_ih is None:
            return fused(from_input, from_hidden, hidden)
        return fused(from_input, from_hidden, hidden, b_ih, b_hh)

    x_reset, x_update, x_new = F.linear(input, w_ih, b_ih).chunk(3, 1)
    h_reset, h_update, h_new = F.linear(hidden, w_hh, b_hh).chunk(3, 1)

    reset = F.sigmoid(x_reset + h_reset)
    update = F.sigmoid(x_update + h_update)
    candidate = F.tanh(x_new + reset * h_new)
    return candidate + update * (hidden - candidate)
def StackedRNN(inners, num_layers, lstm=False, dropout=0, train=True):
    """Compose per-direction sequence runners into a multi-layer RNN.

    Args:
        inners: per-direction callables ``inner(input, hidden, weight)`` that
            return ``(hidden, output)``. Their count is the number of
            directions.
        num_layers: number of stacked layers.
        lstm: if True, ``hidden`` is a pair ``(h, c)``. It is split into one
            ``(h, c)`` pair per layer/direction and recombined afterwards.
        dropout: dropout probability applied between layers, never after
            the last layer.
        train: passed as ``training`` to ``F.dropout``.

    Returns:
        ``forward(input, hidden, weight) -> (next_hidden, output)``. ``weight``
        must hold one entry per (layer, direction) pair.
    """
    num_directions = len(inners)
    total_layers = num_layers * num_directions

    def _stack(states):
        # Concatenate per-layer states along dim 0, then reshape so the
        # leading dim is the (layer, direction) index.
        return torch.cat(states, 0).view(total_layers, *states[0].size())

    def forward(input, hidden, weight):
        assert len(weight) == total_layers
        if lstm:
            # (h_all, c_all) -> [(h_0, c_0), (h_1, c_1), ...]
            hidden = list(zip(*hidden))

        layer_states = []
        for layer in range(num_layers):
            direction_outputs = []
            for direction, inner in enumerate(inners):
                idx = layer * num_directions + direction
                state, out = inner(input, hidden[idx], weight[idx])
                layer_states.append(state)
                direction_outputs.append(out)
            # Join the directions along the last (feature) dimension.
            input = torch.cat(direction_outputs, input.dim() - 1)
            if dropout != 0 and layer + 1 < num_layers:
                input = F.dropout(input, p=dropout, training=train, inplace=False)

        if lstm:
            h_states, c_states = zip(*layer_states)
            return (_stack(h_states), _stack(c_states)), input
        return _stack(layer_states), input

    return forward
def Recurrent(inner, reverse=False):
    """Turn a single-step cell into a runner over a padded sequence.

    ``inner(step_input, hidden, *weight)`` is called once per index along
    dim 0 of ``input``, from last to first when ``reverse`` is True.

    Returns:
        ``forward(input, hidden, weight) -> (hidden, output)``. ``output``
        stacks the per-step results in original time order. If the cell
        returns a tuple (e.g. LSTM ``(h, c)``), only its first element is
        recorded as the step output.
    """
    def forward(input, hidden, weight):
        seq_len = input.size(0)
        order = reversed(range(seq_len)) if reverse else range(seq_len)
        step_outputs = []
        for t in order:
            hidden = inner(input[t], hidden, *weight)
            step_outputs.append(hidden[0] if isinstance(hidden, tuple) else hidden)
        if reverse:
            # Outputs were produced back-to-front; restore time order.
            step_outputs = step_outputs[::-1]
        stacked = torch.cat(step_outputs, 0).view(seq_len, *step_outputs[0].size())
        return hidden, stacked
    return forward
def variable_recurrent_factory(batch_sizes):
    """Return a ``Recurrent``-style factory bound to packed ``batch_sizes``.

    The returned ``fac(inner, reverse=False)`` builds a
    ``VariableRecurrentReverse`` runner when ``reverse`` is truthy and a
    ``VariableRecurrent`` runner otherwise.
    """
    def fac(inner, reverse=False):
        builder = VariableRecurrentReverse if reverse else VariableRecurrent
        return builder(batch_sizes, inner)
    return fac
def VariableRecurrent(batch_sizes, inner):
    """Build a first-to-last runner over a packed variable-length sequence.

    ``input`` stores the time steps back to back along dim 0, with step ``t``
    occupying ``batch_sizes[t]`` rows. When the batch shrinks between steps,
    the trailing rows of the hidden state are split off and kept, because
    their sequences have ended. Only shrinking is handled, so ``batch_sizes``
    is assumed to be non-increasing.

    ``hidden`` may be a single tensor or a tuple of tensors;
    ``inner(step_input, hidden, *weight)`` must return the same structure.

    Returns:
        ``forward(input, hidden, weight) -> (hidden, output)``. ``output``
        concatenates each step's first hidden component along dim 0.
        ``hidden`` holds every sequence's final state in original row order.
    """
    def forward(input, hidden, weight):
        is_flat = not isinstance(hidden, tuple)
        state = (hidden,) if is_flat else hidden
        finished = []  # states split off as their sequences ended
        step_outputs = []
        offset = 0
        prev_size = batch_sizes[0]

        for size in batch_sizes:
            step_input = input[offset:offset + size]
            offset += size
            shrink = prev_size - size
            if shrink > 0:
                finished.append(tuple(h[-shrink:] for h in state))
                state = tuple(h[:-shrink] for h in state)
            prev_size = size
            if is_flat:
                state = (inner(step_input, state[0], *weight),)
            else:
                state = inner(step_input, state, *weight)
            step_outputs.append(state[0])

        finished.append(state)
        # Pieces were peeled off the bottom of the batch, so concatenating
        # them newest-first restores the original row order.
        merged = tuple(torch.cat(parts, 0) for parts in zip(*reversed(finished)))
        assert merged[0].size(0) == batch_sizes[0]
        output = torch.cat(step_outputs, 0)
        return (merged[0] if is_flat else merged), output
    return forward
def VariableRecurrentReverse(batch_sizes, inner):
    """Build a last-to-first runner over a packed variable-length sequence.

    Steps are visited from the end of ``input`` backwards. Processing starts
    with the first ``batch_sizes[-1]`` rows of the initial hidden state.
    Whenever the batch grows, the matching rows of the initial hidden state
    are appended so that newly started sequences begin from their initial
    state.

    Returns:
        ``forward(input, hidden, weight) -> (hidden, output)``. ``output``
        concatenates per-step outputs in forward time order along dim 0.
        ``hidden`` is the state after the final (earliest) step.
    """
    def forward(input, hidden, weight):
        is_flat = not isinstance(hidden, tuple)
        initial = (hidden,) if is_flat else hidden
        prev_size = batch_sizes[-1]
        state = tuple(h[:prev_size] for h in initial)
        step_outputs = []
        end = input.size(0)

        for size in reversed(batch_sizes):
            if size > prev_size:
                # Rows prev_size..size-1 join here; seed them from the
                # corresponding slice of the initial hidden state.
                state = tuple(torch.cat((h, h0[prev_size:size]), 0)
                              for h, h0 in zip(state, initial))
            prev_size = size
            step_input = input[end - size:end]
            end -= size
            if is_flat:
                state = (inner(step_input, state[0], *weight),)
            else:
                state = inner(step_input, state, *weight)
            step_outputs.append(state[0])

        output = torch.cat(step_outputs[::-1], 0)
        return (state[0] if is_flat else state), output
    return forward
def AutogradRNN(mode, input_size, hidden_size, num_layers=1, batch_first=False,
                dropout=0, train=True, bidirectional=False, batch_sizes=None,
                dropout_state=None, flat_weight=None):
    """Build an RNN forward function implemented purely with autograd ops.

    Args:
        mode: one of ``'RNN_RELU'``, ``'RNN_TANH'``, ``'LSTM'``, ``'GRU'``.
        input_size, hidden_size, dropout_state, flat_weight: accepted so the
            signature matches ``CudnnRNN`` (``RNN`` passes the same arguments
            to both); not used by this implementation.
        num_layers: number of stacked layers.
        batch_first: if True and ``batch_sizes`` is None, input and output
            have dims 0 and 1 swapped around the computation.
        dropout, train: inter-layer dropout probability and training flag.
        bidirectional: add a second, time-reversed runner per layer.
        batch_sizes: per-step batch sizes for packed input, or None for
            padded input.

    Returns:
        ``forward(input, weight, hidden) -> (output, next_hidden)``.

    Raises:
        ValueError: if ``mode`` is not recognised. ``ValueError`` subclasses
            ``Exception``, so existing ``except Exception`` handlers still
            catch it.
    """
    if mode == 'RNN_RELU':
        cell = RNNReLUCell
    elif mode == 'RNN_TANH':
        cell = RNNTanhCell
    elif mode == 'LSTM':
        cell = LSTMCell
    elif mode == 'GRU':
        cell = GRUCell
    else:
        raise ValueError('Unknown mode: {}'.format(mode))

    if batch_sizes is None:
        rec_factory = Recurrent
    else:
        rec_factory = variable_recurrent_factory(batch_sizes)

    if bidirectional:
        layer = (rec_factory(cell), rec_factory(cell, reverse=True))
    else:
        layer = (rec_factory(cell),)

    func = StackedRNN(layer,
                      num_layers,
                      (mode == 'LSTM'),
                      dropout=dropout,
                      train=train)

    # Packed input (batch_sizes given) is never transposed.
    transpose = batch_first and batch_sizes is None

    def forward(input, weight, hidden):
        if transpose:
            input = input.transpose(0, 1)
        # StackedRNN takes (input, hidden, weight); this closure takes
        # (input, weight, hidden).
        nexth, output = func(input, hidden, weight)
        if transpose:
            output = output.transpose(0, 1)
        return output, nexth

    return forward
class CudnnRNN(NestedIOFunction):
    """RNN autograd function whose forward/backward run through ``cudnn.rnn``.

    The constructor takes the same arguments as ``AutogradRNN``, so ``RNN``
    can build either implementation from one argument list. This class
    allocates output and gradient buffers and hands them to
    ``torch.backends.cudnn.rnn``, which performs the computation.
    """

    def __init__(self, mode, input_size, hidden_size, num_layers=1,
                 batch_first=False, dropout=0, train=True, bidirectional=False,
                 batch_sizes=None, dropout_state=None, flat_weight=None):
        super(CudnnRNN, self).__init__()
        # None sentinel rather than a mutable ``{}`` default, so each instance
        # gets its own dict unless the caller deliberately shares one.
        if dropout_state is None:
            dropout_state = {}
        self.mode = cudnn.rnn.get_cudnn_mode(mode)
        self.input_mode = cudnn.CUDNN_LINEAR_INPUT
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.dropout = dropout
        self.train = train
        # Integer flag and direction count.
        # NOTE(review): presumably consumed by cudnn.rnn; not visible here.
        self.bidirectional = 1 if bidirectional else 0
        self.num_directions = 2 if bidirectional else 1
        self.batch_sizes = batch_sizes
        self.dropout_seed = torch.IntTensor(1).random_()[0]
        self.dropout_state = dropout_state
        # Pre-flattened weight buffer supplied by the module, if any.
        self.weight_buf = flat_weight
        if flat_weight is None:
            warnings.warn("RNN module weights are not part of a single contiguous "
                          "chunk of memory. This means they need to be compacted "
                          "at every call, possibly greatly increasing memory usage. "
                          "To compact weights again call flatten_parameters().", stacklevel=5)

    def forward_extended(self, input, weight, hx):
        """Run the cuDNN forward pass.

        Args:
            input: input tensor; must satisfy ``cudnn.is_acceptable``.
            weight: per-layer weights, passed through to ``cudnn.rnn.forward``.
            hx: initial hidden state, either a tensor or a tuple of tensors.

        Returns:
            ``(output, hy)``. ``hy`` has the same structure as ``hx``.
        """
        assert cudnn.is_acceptable(input)
        # TODO: raise a warning if weight_data_ptr is None
        # Empty buffers of matching tensor type; cudnn.rnn.forward is expected
        # to populate them, since they are returned without further writes.
        output = input.new()
        if torch.is_tensor(hx):
            hy = hx.new()
        else:
            hy = tuple(h.new() for h in hx)
        cudnn.rnn.forward(self, input, hx, weight, output, hy)
        self.save_for_backward(input, hx, weight, output)
        return output, hy

    def backward_extended(self, grad_output, grad_hy):
        """Run the cuDNN backward pass.

        Returns:
            ``(grad_input, grad_weight, grad_hx)``. ``grad_weight`` mirrors the
            per-layer structure of ``weight``. Its entries are all ``None``
            when no input after the first requires a gradient.
        """
        input, hx, weight, output = self.saved_tensors
        input = input.contiguous()
        assert cudnn.is_acceptable(input)
        grad_input = input.new()
        if torch.is_tensor(hx):
            grad_hx = input.new()
        else:
            grad_hx = tuple(h.new() for h in hx)

        # NOTE(review): ``self.reserve`` is never assigned in this class;
        # presumably cudnn.rnn.forward stores a reserve buffer there that the
        # backward calls below may modify. When the graph is retained, a clone
        # is taken here and restored afterwards. Keep this ordering intact.
        if self.retain_variables:
            self._reserve_clone = self.reserve.clone()

        cudnn.rnn.backward_grad(
            self,
            input,
            hx,
            weight,
            output,
            grad_output,
            grad_hy,
            grad_input,
            grad_hx)

        if any(self.needs_input_grad[1:]):
            grad_weight = [tuple(w.new().resize_as_(w) for w in layer_weight)
                           for layer_weight in weight]
            # Note: this call takes ``output`` before ``weight``.
            cudnn.rnn.backward_weight(
                self,
                input,
                hx,
                output,
                weight,
                grad_weight)
        else:
            grad_weight = [(None,) * len(layer_weight) for layer_weight in weight]

        if self.retain_variables:
            self.reserve = self._reserve_clone
            del self._reserve_clone
        return grad_input, grad_weight, grad_hx
def RNN(*args, **kwargs):
    """Return a forward function that chooses the RNN backend on every call.

    Each call checks ``cudnn.is_acceptable(input.data)``. If the check
    passes, a new ``CudnnRNN(*args, **kwargs)`` is built; otherwise an
    ``AutogradRNN(*args, **kwargs)`` closure is built. The chosen
    implementation is then applied to ``(input, *fargs, **fkwargs)``.
    """
    def forward(input, *fargs, **fkwargs):
        backend = CudnnRNN if cudnn.is_acceptable(input.data) else AutogradRNN
        impl = backend(*args, **kwargs)
        return impl(input, *fargs, **fkwargs)
    return forward