import warnings
from torch.autograd import NestedIOFunction
import torch.backends.cudnn as cudnn
from .. import functional as F
import torch
from .thnn import rnnFusedPointwise as fusedBackend
import itertools
from functools import partial
try:
import torch.backends.cudnn.rnn
except ImportError:
pass
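# Single-step cell functions. Each takes one time step of input, the current
# hidden state, the layer weights (w_ih, w_hh) and optional biases (b_ih,
# b_hh), and returns the next hidden state.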
def RNNReLUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
hy = F.relu(F.linear(input, w_ih, b_ih) + F.linear(hidden, w_hh, b_hh))
return hy
def RNNTanhCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
hy = torch.tanh(F.linear(input, w_ih, b_ih) + F.linear(hidden, w_hh, b_hh))
return hy
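# One step of an LSTM. On CUDA tensors the gate nonlinearities are handed to
# the fused THNN kernel; on CPU the four gates are computed from a single
# pair of linear maps and split with chunk(4, 1).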
def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
if input.is_cuda:
igates = F.linear(input, w_ih)
hgates = F.linear(hidden[0], w_hh)
state = fusedBackend.LSTMFused.apply
return state(igates, hgates, hidden[1]) if b_ih is None else state(igates, hgates, hidden[1], b_ih, b_hh)
hx, cx = hidden
gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
ingate = torch.sigmoid(ingate)
forgetgate = torch.sigmoid(forgetgate)
cellgate = torch.tanh(cellgate)
outgate = torch.sigmoid(outgate)
cy = (forgetgate * cx) + (ingate * cellgate)
hy = outgate * torch.tanh(cy)
return hy, cy
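# One step of a GRU. gi/gh hold the input and hidden projections for the
# reset, update ("input") and new gates; the next hidden state interpolates
# between the candidate newgate and the previous hidden state, weighted by
# the update gate. CUDA tensors again go through the fused backend kernel.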
def GRUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
if input.is_cuda:
gi = F.linear(input, w_ih)
gh = F.linear(hidden, w_hh)
state = fusedBackend.GRUFused.apply
return state(gi, gh, hidden) if b_ih is None else state(gi, gh, hidden, b_ih, b_hh)
gi = F.linear(input, w_ih, b_ih)
gh = F.linear(hidden, w_hh, b_hh)
i_r, i_i, i_n = gi.chunk(3, 1)
h_r, h_i, h_n = gh.chunk(3, 1)
resetgate = torch.sigmoid(i_r + h_r)
inputgate = torch.sigmoid(i_i + h_i)
newgate = torch.tanh(i_n + resetgate * h_n)
hy = newgate + inputgate * (hidden - newgate)
return hy
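# StackedRNN combines per-direction recurrence functions ("inners") into a
# multi-layer network: for each layer it runs every direction, concatenates
# the direction outputs along the feature dimension to form the next layer's
# input, and applies dropout between layers (but not after the last one).
# For LSTMs the hidden state is an (h, c) pair, so it is unzipped before the
# layer loop and re-stacked afterwards.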
def StackedRNN(inners, num_layers, lstm=False, dropout=0, train=True):
num_directions = len(inners)
total_layers = num_layers * num_directions
def forward(input, hidden, weight, batch_sizes):
assert len(weight) == total_layers
next_hidden = []
if lstm:
hidden = list(zip(*hidden))
for i in range(num_layers):
all_output = []
for j, inner in enumerate(inners):
l = i * num_directions + j
hy, output = inner(input, hidden[l], weight[l], batch_sizes)
next_hidden.append(hy)
all_output.append(output)
input = torch.cat(all_output, input.dim() - 1)
if dropout != 0 and i < num_layers - 1:
input = F.dropout(input, p=dropout, training=train, inplace=False)
if lstm:
next_h, next_c = zip(*next_hidden)
next_hidden = (
torch.cat(next_h, 0).view(total_layers, *next_h[0].size()),
torch.cat(next_c, 0).view(total_layers, *next_c[0].size())
)
else:
next_hidden = torch.cat(next_hidden, 0).view(
total_layers, *next_hidden[0].size())
return next_hidden, input
return forward
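# Recurrent unrolls a cell over the leading (time) dimension of a padded
# input; with reverse=True the time steps are visited back to front, which
# implements the backward direction of a bidirectional RNN.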
def Recurrent(inner, reverse=False):
def forward(input, hidden, weight, batch_sizes):
output = []
steps = range(input.size(0) - 1, -1, -1) if reverse else range(input.size(0))
for i in steps:
hidden = inner(input[i], hidden, *weight)
# hack to handle LSTM: the cell returns an (h, c) tuple, but only the
# hidden state h goes into the output sequence
output.append(hidden[0] if isinstance(hidden, tuple) else hidden)
if reverse:
output.reverse()
output = torch.cat(output, 0).view(input.size(0), *output[0].size())
return hidden, output
return forward
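# Variants of Recurrent for packed (variable-length) sequences. batch_sizes
# gives the number of sequences still active at each time step, so the
# hidden state shrinks as sequences end (forward direction) and is grown
# back from the initial hidden state (reverse direction).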
def variable_recurrent_factory(inner, reverse=False):
if reverse:
return VariableRecurrentReverse(inner)
else:
return VariableRecurrent(inner)
def VariableRecurrent(inner):
def forward(input, hidden, weight, batch_sizes):
output = []
input_offset = 0
last_batch_size = batch_sizes[0]
hiddens = []
flat_hidden = not isinstance(hidden, tuple)
if flat_hidden:
hidden = (hidden,)
for batch_size in batch_sizes:
step_input = input[input_offset:input_offset + batch_size]
input_offset += batch_size
dec = last_batch_size - batch_size
if dec > 0:
hiddens.append(tuple(h[-dec:] for h in hidden))
hidden = tuple(h[:-dec] for h in hidden)
last_batch_size = batch_size
if flat_hidden:
hidden = (inner(step_input, hidden[0], *weight),)
else:
hidden = inner(step_input, hidden, *weight)
output.append(hidden[0])
hiddens.append(hidden)
hiddens.reverse()
hidden = tuple(torch.cat(h, 0) for h in zip(*hiddens))
assert hidden[0].size(0) == batch_sizes[0]
if flat_hidden:
hidden = hidden[0]
output = torch.cat(output, 0)
return hidden, output
return forward
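# Reverse-direction counterpart of VariableRecurrent: it walks the packed
# steps from last to first and re-attaches rows of the initial hidden state
# whenever the active batch size increases.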
def VariableRecurrentReverse(inner):
def forward(input, hidden, weight, batch_sizes):
output = []
input_offset = input.size(0)
last_batch_size = batch_sizes[-1]
initial_hidden = hidden
flat_hidden = not isinstance(hidden, tuple)
if flat_hidden:
hidden = (hidden,)
initial_hidden = (initial_hidden,)
hidden = tuple(h[:batch_sizes[-1]] for h in hidden)
for i in reversed(range(len(batch_sizes))):
batch_size = batch_sizes[i]
inc = batch_size - last_batch_size
if inc > 0:
hidden = tuple(torch.cat((h, ih[last_batch_size:batch_size]), 0)
for h, ih in zip(hidden, initial_hidden))
last_batch_size = batch_size
step_input = input[input_offset - batch_size:input_offset]
input_offset -= batch_size
if flat_hidden:
hidden = (inner(step_input, hidden[0], *weight),)
else:
hidden = inner(step_input, hidden, *weight)
output.append(hidden[0])
output.reverse()
output = torch.cat(output, 0)
if flat_hidden:
hidden = hidden[0]
return hidden, output
return forward
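# AutogradRNN is the pure autograd fallback used when cuDNN cannot handle
# the input. It selects the cell from `mode`, builds one recurrence per
# direction, stacks the layers with StackedRNN, and handles batch_first by
# transposing the (non-packed) input around the call.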
def AutogradRNN(mode, input_size, hidden_size, num_layers=1, batch_first=False,
dropout=0, train=True, bidirectional=False, variable_length=False,
dropout_state=None, flat_weight=None):
if mode == 'RNN_RELU':
cell = RNNReLUCell
elif mode == 'RNN_TANH':
cell = RNNTanhCell
elif mode == 'LSTM':
cell = LSTMCell
elif mode == 'GRU':
cell = GRUCell
else:
raise RuntimeError('Unknown mode: {}'.format(mode))
rec_factory = variable_recurrent_factory if variable_length else Recurrent
if bidirectional:
layer = (rec_factory(cell), rec_factory(cell, reverse=True))
else:
layer = (rec_factory(cell),)
func = StackedRNN(layer,
num_layers,
(mode == 'LSTM'),
dropout=dropout,
train=train)
def forward(input, weight, hidden, batch_sizes):
if batch_first and not variable_length:
input = input.transpose(0, 1)
nexth, output = func(input, hidden, weight, batch_sizes)
if batch_first and not variable_length:
output = output.transpose(0, 1)
return output, nexth
return forward
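# CudnnRNN delegates the whole stacked, possibly bidirectional run to the
# fused torch._cudnn_rnn primitive: it flattens the per-layer weight lists,
# splits an LSTM hidden state into (hx, cx), and initializes the cuDNN
# dropout state on the input's device.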
def CudnnRNN(mode, input_size, hidden_size, num_layers=1,
batch_first=False, dropout=0, train=True, bidirectional=False,
variable_length=False, dropout_state=None, flat_weight=None):
if dropout_state is None:
dropout_state = {}
mode = cudnn.rnn.get_cudnn_mode(mode)
# TODO: this is a roundabout way of using the Torch RNG to obtain a random dropout seed
dropout_seed = int(torch.IntTensor(1).random_())
if flat_weight is None:
warnings.warn("RNN module weights are not part of single contiguous "
"chunk of memory. This means they need to be compacted "
"at every call, possibly greatly increasing memory usage. "
"To compact weights again call flatten_parameters().", stacklevel=5)
def forward(input, weight, hx, batch_sizes):
if mode == cudnn.CUDNN_LSTM:
hx, cx = hx
else:
cx = None
handle = cudnn.get_handle()
with torch.cuda.device(input.get_device()):
dropout_ts = cudnn.rnn.init_dropout_state(dropout, train, dropout_seed, dropout_state)
weight_arr = list(itertools.chain.from_iterable(weight))
weight_stride0 = len(weight[0])
output, hy, cy, reserve, new_weight_buf = torch._cudnn_rnn(
input, weight_arr, weight_stride0,
flat_weight,
hx, cx,
mode, hidden_size, num_layers,
batch_first, dropout, train, bool(bidirectional),
list(batch_sizes.data) if variable_length else (),
dropout_ts)
if cx is not None:
return (output, (hy, cy))
else:
return (output, hy)
return forward
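# RNN is the top-level factory. It defers the choice between the cuDNN path
# and the autograd fallback to call time, when the concrete input is known,
# and wraps the chosen function with the ONNX symbolic override while
# tracing.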
def RNN(*args, **kwargs):
def forward(input, *fargs, **fkwargs):
if cudnn.is_acceptable(input.data):
func = CudnnRNN(*args, **kwargs)
else:
func = AutogradRNN(*args, **kwargs)
# Hack for the tracer that allows us to represent RNNs as single
# nodes and export them to ONNX in this form.
# Check the first argument explicitly to reduce the overhead of creating
# the lambda. We need special handling here because the forward()
# function gets reconstructed every time RNN() is invoked, and we
# don't want to pay the cost of the decorator invocation each time.
import torch
if torch._C._get_tracing_state():
import torch.onnx.symbolic
sym = torch.onnx.symbolic.RNN_symbolic_builder(*args, **kwargs)
cell_type = args[0]
bound_symbolic = partial(torch.onnx.symbolic.rnn_trace_override_symbolic,
cell_type, func, sym)
decorator = torch.onnx.symbolic_override(bound_symbolic)
func = decorator(func)
return func(input, *fargs, **fkwargs)
return forward
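# Illustrative usage (a sketch, not part of the original module; the tensor
# shapes and weight layout below follow what the code above expects):
#
#   func = RNN('RNN_TANH', input_size=10, hidden_size=20, num_layers=1)
#   # `weights` is one [w_ih, w_hh, b_ih, b_hh] list per layer/direction,
#   # `x` is (seq_len, batch, input_size), `h0` is
#   # (num_layers * num_directions, batch, hidden_size); batch_sizes is
#   # None for non-packed input.
#   output, h_n = func(x, weights, h0, None)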