import warnings
from torch.autograd import NestedIOFunction
import torch.backends.cudnn as cudnn
from .. import functional as F
import torch
from .thnn import rnnFusedPointwise as fusedBackend
import itertools
from functools import partial

try:
    import torch.backends.cudnn.rnn
except ImportError:
    pass
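# Functional RNN implementation used by the torch.nn RNN modules: the *Cell
# functions below compute a single time step, Recurrent/VariableRecurrent wrap
# a cell into a time-loop over a sequence, StackedRNN composes layers and
# directions, and RNN() dispatches between the cuDNN (CudnnRNN) and pure
# autograd (AutogradRNN) code paths.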


def RNNReLUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
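    """Single-step Elman RNN cell with a ReLU nonlinearity:
    ``hy = relu(input @ w_ih.t() + b_ih + hidden @ w_hh.t() + b_hh)``.
    """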
    hy = F.relu(F.linear(input, w_ih, b_ih) + F.linear(hidden, w_hh, b_hh))
    return hy


def RNNTanhCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
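    """Single-step Elman RNN cell with a tanh nonlinearity:
    ``hy = tanh(input @ w_ih.t() + b_ih + hidden @ w_hh.t() + b_hh)``.
    """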
    hy = torch.tanh(F.linear(input, w_ih, b_ih) + F.linear(hidden, w_hh, b_hh))
    return hy


def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
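    """Single-step LSTM cell. ``hidden`` is the ``(hx, cx)`` tuple; returns the
    new ``(hy, cy)`` pair. CUDA inputs take the fused-kernel fast path, while
    CPU inputs compute the four gates explicitly.
    """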
    if input.is_cuda:
        igates = F.linear(input, w_ih)
        hgates = F.linear(hidden[0], w_hh)
        state = fusedBackend.LSTMFused.apply
        return state(igates, hgates, hidden[1]) if b_ih is None else state(igates, hgates, hidden[1], b_ih, b_hh)

    hx, cx = hidden
    gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)

    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)

    return hy, cy


def GRUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
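    """Single-step GRU cell; returns the next hidden state ``hy``. CUDA inputs
    take the fused-kernel fast path, while CPU inputs compute the reset, update
    and new gates explicitly.
    """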
    if input.is_cuda:
        gi = F.linear(input, w_ih)
        gh = F.linear(hidden, w_hh)
        state = fusedBackend.GRUFused.apply
        return state(gi, gh, hidden) if b_ih is None else state(gi, gh, hidden, b_ih, b_hh)

    gi = F.linear(input, w_ih, b_ih)
    gh = F.linear(hidden, w_hh, b_hh)
    i_r, i_i, i_n = gi.chunk(3, 1)
    h_r, h_i, h_n = gh.chunk(3, 1)

    resetgate = torch.sigmoid(i_r + h_r)
    inputgate = torch.sigmoid(i_i + h_i)
    newgate = torch.tanh(i_n + resetgate * h_n)
    hy = newgate + inputgate * (hidden - newgate)

    return hy


def StackedRNN(inners, num_layers, lstm=False, dropout=0, train=True):
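    """Compose the per-direction time-loops in ``inners`` into a multi-layer
    (and optionally bidirectional) RNN. The returned
    ``forward(input, hidden, weight, batch_sizes)`` closure concatenates the
    per-direction outputs along the feature dimension, applies dropout between
    layers, and stacks the final hidden states (an ``(h, c)`` pair of stacks
    when ``lstm`` is True).
    """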
    num_directions = len(inners)
    total_layers = num_layers * num_directions

    def forward(input, hidden, weight, batch_sizes):
        assert len(weight) == total_layers
        next_hidden = []

        if lstm:
            hidden = list(zip(*hidden))

        for i in range(num_layers):
            all_output = []
            for j, inner in enumerate(inners):
                l = i * num_directions + j

                hy, output = inner(input, hidden[l], weight[l], batch_sizes)
                next_hidden.append(hy)
                all_output.append(output)

            input = torch.cat(all_output, input.dim() - 1)

            if dropout != 0 and i < num_layers - 1:
                input = F.dropout(input, p=dropout, training=train, inplace=False)

        if lstm:
            next_h, next_c = zip(*next_hidden)
            next_hidden = (
                torch.cat(next_h, 0).view(total_layers, *next_h[0].size()),
                torch.cat(next_c, 0).view(total_layers, *next_c[0].size())
            )
        else:
            next_hidden = torch.cat(next_hidden, 0).view(
                total_layers, *next_hidden[0].size())

        return next_hidden, input

    return forward


def Recurrent(inner, reverse=False):
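    """Wrap a single-step cell into a time-loop over a fixed-length sequence
    (dimension 0 of ``input``), optionally in reverse order. The returned
    forward closure yields the final hidden state and the stacked per-step
    outputs.
    """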
    def forward(input, hidden, weight, batch_sizes):
        output = []
        steps = range(input.size(0) - 1, -1, -1) if reverse else range(input.size(0))
        for i in steps:
            hidden = inner(input[i], hidden, *weight)
            # hack to handle LSTM cells, which return an (h, c) tuple; the step output is h
            output.append(hidden[0] if isinstance(hidden, tuple) else hidden)

        if reverse:
            output.reverse()
        output = torch.cat(output, 0).view(input.size(0), *output[0].size())

        return hidden, output

    return forward


def variable_recurrent_factory(inner, reverse=False):
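    """Return the packed-sequence time-loop for ``inner``, reversed or not."""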
    if reverse:
        return VariableRecurrentReverse(inner)
    else:
        return VariableRecurrent(inner)


def VariableRecurrent(inner):
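    """Forward-direction time-loop over a packed (variable-length) batch.
    ``batch_sizes`` gives the number of active sequences at each step; hidden
    states of sequences that end early are set aside and re-assembled at the
    end so the returned hidden state covers the whole batch.
    """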
    def forward(input, hidden, weight, batch_sizes):
        output = []
        input_offset = 0
        last_batch_size = batch_sizes[0]
        hiddens = []
        flat_hidden = not isinstance(hidden, tuple)
        if flat_hidden:
            hidden = (hidden,)
        for batch_size in batch_sizes:
            step_input = input[input_offset:input_offset + batch_size]
            input_offset += batch_size

            dec = last_batch_size - batch_size
            if dec > 0:
                hiddens.append(tuple(h[-dec:] for h in hidden))
                hidden = tuple(h[:-dec] for h in hidden)
            last_batch_size = batch_size

            if flat_hidden:
                hidden = (inner(step_input, hidden[0], *weight),)
            else:
                hidden = inner(step_input, hidden, *weight)

            output.append(hidden[0])
        hiddens.append(hidden)
        hiddens.reverse()

        hidden = tuple(torch.cat(h, 0) for h in zip(*hiddens))
        assert hidden[0].size(0) == batch_sizes[0]
        if flat_hidden:
            hidden = hidden[0]
        output = torch.cat(output, 0)

        return hidden, output

    return forward


def VariableRecurrentReverse(inner):
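    """Reverse-direction time-loop over a packed (variable-length) batch.
    Iterates from the last step to the first, growing the hidden state back to
    the full batch size from ``initial_hidden`` as shorter sequences come back
    into scope.
    """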
    def forward(input, hidden, weight, batch_sizes):
        output = []
        input_offset = input.size(0)
        last_batch_size = batch_sizes[-1]
        initial_hidden = hidden
        flat_hidden = not isinstance(hidden, tuple)
        if flat_hidden:
            hidden = (hidden,)
            initial_hidden = (initial_hidden,)
        hidden = tuple(h[:batch_sizes[-1]] for h in hidden)
        for i in reversed(range(len(batch_sizes))):
            batch_size = batch_sizes[i]
            inc = batch_size - last_batch_size
            if inc > 0:
                hidden = tuple(torch.cat((h, ih[last_batch_size:batch_size]), 0)
                               for h, ih in zip(hidden, initial_hidden))
            last_batch_size = batch_size
            step_input = input[input_offset - batch_size:input_offset]
            input_offset -= batch_size

            if flat_hidden:
                hidden = (inner(step_input, hidden[0], *weight),)
            else:
                hidden = inner(step_input, hidden, *weight)
            output.append(hidden[0])

        output.reverse()
        output = torch.cat(output, 0)
        if flat_hidden:
            hidden = hidden[0]
        return hidden, output

    return forward


def AutogradRNN(mode, input_size, hidden_size, num_layers=1, batch_first=False,
                dropout=0, train=True, bidirectional=False, variable_length=False,
                dropout_state=None, flat_weight=None):
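    """Pure-Python (autograd) RNN built from the cell functions above: selects
    the cell for ``mode``, wraps it in the fixed- or variable-length (and
    optionally bidirectional) time-loop, stacks the layers with ``StackedRNN``
    and returns a ``forward(input, weight, hidden, batch_sizes)`` closure that
    yields ``(output, next_hidden)``.
    """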
    if mode == 'RNN_RELU':
        cell = RNNReLUCell
    elif mode == 'RNN_TANH':
        cell = RNNTanhCell
    elif mode == 'LSTM':
        cell = LSTMCell
    elif mode == 'GRU':
        cell = GRUCell
    else:
        raise Exception('Unknown mode: {}'.format(mode))

    rec_factory = variable_recurrent_factory if variable_length else Recurrent

    if bidirectional:
        layer = (rec_factory(cell), rec_factory(cell, reverse=True))
    else:
        layer = (rec_factory(cell),)

    func = StackedRNN(layer,
                      num_layers,
                      (mode == 'LSTM'),
                      dropout=dropout,
                      train=train)

    def forward(input, weight, hidden, batch_sizes):
        if batch_first and not variable_length:
            input = input.transpose(0, 1)

        nexth, output = func(input, hidden, weight, batch_sizes)

        if batch_first and not variable_length:
            output = output.transpose(0, 1)

        return output, nexth

    return forward


def CudnnRNN(mode, input_size, hidden_size, num_layers=1,
             batch_first=False, dropout=0, train=True, bidirectional=False,
             variable_length=False, dropout_state=None, flat_weight=None):
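    """cuDNN-backed counterpart of ``AutogradRNN`` with the same signature.
    Returns a ``forward(input, weight, hx, batch_sizes)`` closure that calls
    ``torch._cudnn_rnn`` and yields ``(output, next_hidden)``, where the next
    hidden state is an ``(hy, cy)`` pair for LSTMs.
    """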
    if dropout_state is None:
        dropout_state = {}
    mode = cudnn.rnn.get_cudnn_mode(mode)
    # TODO: This is a really goofy way of using the Torch RNG to get a random number
    dropout_seed = int(torch.IntTensor(1).random_())
    if flat_weight is None:
        warnings.warn("RNN module weights are not part of a single contiguous "
                      "chunk of memory. This means they need to be compacted "
                      "at every call, possibly greatly increasing memory usage. "
                      "To compact weights again call flatten_parameters().", stacklevel=5)

    def forward(input, weight, hx, batch_sizes):
        if mode == cudnn.CUDNN_LSTM:
            hx, cx = hx
        else:
            cx = None

        handle = cudnn.get_handle()
        with torch.cuda.device(input.get_device()):
            dropout_ts = cudnn.rnn.init_dropout_state(dropout, train, dropout_seed, dropout_state)

            weight_arr = list(itertools.chain.from_iterable(weight))
            weight_stride0 = len(weight[0])

            output, hy, cy, reserve, new_weight_buf = torch._cudnn_rnn(
                input, weight_arr, weight_stride0,
                flat_weight,
                hx, cx,
                mode, hidden_size, num_layers,
                batch_first, dropout, train, bool(bidirectional),
                list(batch_sizes.data) if variable_length else (),
                dropout_ts)

        if cx is not None:
            return (output, (hy, cy))
        else:
            return (output, hy)

    return forward


def RNN(*args, **kwargs):
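    """Factory used by the nn RNN modules: returns a forward closure that
    dispatches to ``CudnnRNN`` when cuDNN accepts the input and to
    ``AutogradRNN`` otherwise, wiring in an ONNX symbolic override when the
    tracer is active so the whole RNN can be exported as a single node.
    """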
    def forward(input, *fargs, **fkwargs):
        if cudnn.is_acceptable(input.data):
            func = CudnnRNN(*args, **kwargs)
        else:
            func = AutogradRNN(*args, **kwargs)

        # Hack for the tracer that allows us to represent RNNs as single
        # nodes and export them to ONNX in this form
        # Check the first argument explicitly to reduce the overhead of creating
        # the lambda. We need special handling here because the forward()
        # function gets reconstructed each and every time when RNN() is invoked
        # and we don't want to pay the cost of decorator invocation
        import torch
        if torch._C._get_tracing_state():
            import torch.onnx.symbolic
            sym = torch.onnx.symbolic.RNN_symbolic_builder(*args, **kwargs)
            cell_type = args[0]

            bound_symbolic = partial(torch.onnx.symbolic.rnn_trace_override_symbolic,
                                     cell_type, func, sym)

            decorator = torch.onnx.symbolic_override(bound_symbolic)
            func = decorator(func)

        return func(input, *fargs, **fkwargs)

    return forward
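# Illustrative usage sketch (comments only, not executed): the closure returned
# by RNN() is normally driven by the torch.nn RNN/LSTM/GRU modules, but called
# directly it would look roughly like the following, assuming seq_len-first
# tensors and `all_weights` laid out as one [w_ih, w_hh, b_ih, b_hh] list per
# layer and direction (the `weight` argument threaded through StackedRNN):
#
#     func = RNN('LSTM', input_size, hidden_size, num_layers=1,
#                batch_first=False, dropout=0, train=True, bidirectional=False)
#     output, (h_n, c_n) = func(input, all_weights, (h_0, c_0), None)
#
# The trailing None is `batch_sizes`; it is only populated for packed,
# variable-length input (the VariableRecurrent code path above).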