import torch
import warnings
from torch.autograd import Function, NestedIOFunction, Variable
import torch.backends.cudnn as cudnn
from .. import functional as F
from .thnn import rnnFusedPointwise as fusedBackend

try:
    import torch.backends.cudnn.rnn
except ImportError:
    pass


def RNNReLUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    hy = F.relu(F.linear(input, w_ih, b_ih) + F.linear(hidden, w_hh, b_hh))
    return hy


def RNNTanhCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    hy = F.tanh(F.linear(input, w_ih, b_ih) + F.linear(hidden, w_hh, b_hh))
    return hy
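
# Shape sketch for the elementwise cells above (sizes are illustrative, not
# taken from this file): for batch B, input size I and hidden size H,
# `input` is B x I, `hidden` is B x H, `w_ih` is H x I and `w_hh` is H x H,
# so both F.linear calls produce B x H and the cell returns a B x H state:
#     x = Variable(torch.randn(3, 10))      # batch 3, input size 10
#     hx = Variable(torch.randn(3, 20))     # batch 3, hidden size 20
#     w_ih = Variable(torch.randn(20, 10))
#     w_hh = Variable(torch.randn(20, 20))
#     hy = RNNTanhCell(x, hx, w_ih, w_hh)   # -> 3 x 20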


def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    if input.is_cuda:
        igates = F.linear(input, w_ih)
        hgates = F.linear(hidden[0], w_hh)
        state = fusedBackend.LSTMFused()
        return state(igates, hgates, hidden[1]) if b_ih is None else state(igates, hgates, hidden[1], b_ih, b_hh)

    hx, cx = hidden
    gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)

    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate = F.sigmoid(ingate)
    forgetgate = F.sigmoid(forgetgate)
    cellgate = F.tanh(cellgate)
    outgate = F.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * F.tanh(cy)

    return hy, cy
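
# For reference, the CPU branch of LSTMCell above computes the standard LSTM
# update (gate order in the 4*H-wide `gates` tensor is input, forget, cell,
# output):
#     i = sigmoid(W_ii x + b_ii + W_hi h + b_hi)
#     f = sigmoid(W_if x + b_if + W_hf h + b_hf)
#     g = tanh(W_ig x + b_ig + W_hg h + b_hg)
#     o = sigmoid(W_io x + b_io + W_ho h + b_ho)
#     c' = f * c + i * g
#     h' = o * tanh(c')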


def GRUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):

    if input.is_cuda:
        gi = F.linear(input, w_ih)
        gh = F.linear(hidden, w_hh)
        state = fusedBackend.GRUFused()
        return state(gi, gh, hidden) if b_ih is None else state(gi, gh, hidden, b_ih, b_hh)

    gi = F.linear(input, w_ih, b_ih)
    gh = F.linear(hidden, w_hh, b_hh)
    i_r, i_i, i_n = gi.chunk(3, 1)
    h_r, h_i, h_n = gh.chunk(3, 1)

    resetgate = F.sigmoid(i_r + h_r)
    inputgate = F.sigmoid(i_i + h_i)
    newgate = F.tanh(i_n + resetgate * h_n)
    hy = newgate + inputgate * (hidden - newgate)

    return hy
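
# For reference, the CPU branch of GRUCell above computes the standard GRU
# update, where `inputgate` plays the role of the update gate z:
#     r = sigmoid(W_ir x + b_ir + W_hr h + b_hr)
#     z = sigmoid(W_iz x + b_iz + W_hz h + b_hz)
#     n = tanh(W_in x + b_in + r * (W_hn h + b_hn))
#     h' = n + z * (h - n) = (1 - z) * n + z * h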


def StackedRNN(inners, num_layers, lstm=False, dropout=0, train=True):

    num_directions = len(inners)
    total_layers = num_layers * num_directions

    def forward(input, hidden, weight):
        assert(len(weight) == total_layers)
        next_hidden = []

        if lstm:
            hidden = list(zip(*hidden))

        for i in range(num_layers):
            all_output = []
            for j, inner in enumerate(inners):
                l = i * num_directions + j

                hy, output = inner(input, hidden[l], weight[l])
                next_hidden.append(hy)
                all_output.append(output)

            input = torch.cat(all_output, input.dim() - 1)

            if dropout != 0 and i < num_layers - 1:
                input = F.dropout(input, p=dropout, training=train, inplace=False)

        if lstm:
            next_h, next_c = zip(*next_hidden)
            next_hidden = (
                torch.cat(next_h, 0).view(total_layers, *next_h[0].size()),
                torch.cat(next_c, 0).view(total_layers, *next_c[0].size())
            )
        else:
            next_hidden = torch.cat(next_hidden, 0).view(
                total_layers, *next_hidden[0].size())

        return next_hidden, input

    return forward
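
# Layout note (sketch): StackedRNN expects `weight` and `hidden` to be indexed
# by l = layer * num_directions + direction, i.e. for a 2-layer bidirectional
# RNN the order is [layer0 fwd, layer0 bwd, layer1 fwd, layer1 bwd]; the
# per-direction outputs of each layer are concatenated along the feature
# dimension before being fed to the next layer.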


def Recurrent(inner, reverse=False):
    def forward(input, hidden, weight):
        output = []
        steps = range(input.size(0) - 1, -1, -1) if reverse else range(input.size(0))
        for i in steps:
            hidden = inner(input[i], hidden, *weight)
            # hack to handle LSTM
            output.append(hidden[0] if isinstance(hidden, tuple) else hidden)

        if reverse:
            output.reverse()
        output = torch.cat(output, 0).view(input.size(0), *output[0].size())

        return hidden, output

    return forward
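
# Recurrent assumes a time-major, fixed-batch input (seq_len x batch x
# features): it applies `inner` once per time step and stacks the per-step
# hidden states into a seq_len x batch x hidden output, reversing the step
# order when reverse=True.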


def variable_recurrent_factory(batch_sizes):
    def fac(inner, reverse=False):
        if reverse:
            return VariableRecurrentReverse(batch_sizes, inner)
        else:
            return VariableRecurrent(batch_sizes, inner)
    return fac


def VariableRecurrent(batch_sizes, inner):
    def forward(input, hidden, weight):
        output = []
        input_offset = 0
        last_batch_size = batch_sizes[0]
        hiddens = []
        flat_hidden = not isinstance(hidden, tuple)
        if flat_hidden:
            hidden = (hidden,)
        for batch_size in batch_sizes:
            step_input = input[input_offset:input_offset + batch_size]
            input_offset += batch_size

            dec = last_batch_size - batch_size
            if dec > 0:
                hiddens.append(tuple(h[-dec:] for h in hidden))
                hidden = tuple(h[:-dec] for h in hidden)
            last_batch_size = batch_size

            if flat_hidden:
                hidden = (inner(step_input, hidden[0], *weight),)
            else:
                hidden = inner(step_input, hidden, *weight)

            output.append(hidden[0])
        hiddens.append(hidden)
        hiddens.reverse()

        hidden = tuple(torch.cat(h, 0) for h in zip(*hiddens))
        assert hidden[0].size(0) == batch_sizes[0]
        if flat_hidden:
            hidden = hidden[0]
        output = torch.cat(output, 0)

        return hidden, output

    return forward
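
# `batch_sizes` sketch (values are illustrative): for a packed batch of
# sequences with lengths [4, 3, 1], batch_sizes == [3, 2, 2, 1] -- at each
# time step only the still-active sequences are advanced, and the hidden
# states of sequences that just finished are set aside and concatenated back
# at the end so the returned hidden matches the original batch size.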


def VariableRecurrentReverse(batch_sizes, inner):
    def forward(input, hidden, weight):
        output = []
        input_offset = input.size(0)
        last_batch_size = batch_sizes[-1]
        initial_hidden = hidden
        flat_hidden = not isinstance(hidden, tuple)
        if flat_hidden:
            hidden = (hidden,)
            initial_hidden = (initial_hidden,)
        hidden = tuple(h[:batch_sizes[-1]] for h in hidden)
        for batch_size in reversed(batch_sizes):
            inc = batch_size - last_batch_size
            if inc > 0:
                hidden = tuple(torch.cat((h, ih[last_batch_size:batch_size]), 0)
                               for h, ih in zip(hidden, initial_hidden))
            last_batch_size = batch_size
            step_input = input[input_offset - batch_size:input_offset]
            input_offset -= batch_size

            if flat_hidden:
                hidden = (inner(step_input, hidden[0], *weight),)
            else:
                hidden = inner(step_input, hidden, *weight)
            output.append(hidden[0])

        output.reverse()
        output = torch.cat(output, 0)
        if flat_hidden:
            hidden = hidden[0]
        return hidden, output

    return forward


def AutogradRNN(mode, input_size, hidden_size, num_layers=1, batch_first=False,
                dropout=0, train=True, bidirectional=False, batch_sizes=None,
                dropout_state=None, flat_weight=None):

    if mode == 'RNN_RELU':
        cell = RNNReLUCell
    elif mode == 'RNN_TANH':
        cell = RNNTanhCell
    elif mode == 'LSTM':
        cell = LSTMCell
    elif mode == 'GRU':
        cell = GRUCell
    else:
        raise Exception('Unknown mode: {}'.format(mode))

    if batch_sizes is None:
        rec_factory = Recurrent
    else:
        rec_factory = variable_recurrent_factory(batch_sizes)

    if bidirectional:
        layer = (rec_factory(cell), rec_factory(cell, reverse=True))
    else:
        layer = (rec_factory(cell),)

    func = StackedRNN(layer,
                      num_layers,
                      (mode == 'LSTM'),
                      dropout=dropout,
                      train=train)

    def forward(input, weight, hidden):
        if batch_first and batch_sizes is None:
            input = input.transpose(0, 1)

        nexth, output = func(input, hidden, weight)

        if batch_first and batch_sizes is None:
            output = output.transpose(0, 1)

        return output, nexth

    return forward
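
# Composition sketch: AutogradRNN builds forward = StackedRNN(Recurrent(cell))
# (or the variable-length variants when `batch_sizes` is given), so a call
# like forward(input, weight, hidden) unrolls the chosen cell over time for
# every layer and direction purely through autograd ops.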


class CudnnRNN(NestedIOFunction):

    def __init__(self, mode, input_size, hidden_size, num_layers=1,
                 batch_first=False, dropout=0, train=True, bidirectional=False,
                 batch_sizes=None, dropout_state=None, flat_weight=None):
        super(CudnnRNN, self).__init__()
        if dropout_state is None:
            dropout_state = {}
        self.mode = cudnn.rnn.get_cudnn_mode(mode)
        self.input_mode = cudnn.CUDNN_LINEAR_INPUT
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.dropout = dropout
        self.train = train
        self.bidirectional = 1 if bidirectional else 0
        self.num_directions = 2 if bidirectional else 1
        self.batch_sizes = batch_sizes
        self.dropout_seed = torch.IntTensor(1).random_()[0]
        self.dropout_state = dropout_state
        self.weight_buf = flat_weight
        if flat_weight is None:
            warnings.warn("RNN module weights are not part of a single contiguous "
                          "chunk of memory. This means they need to be compacted "
                          "at every call, possibly greatly increasing memory usage. "
                          "To compact weights again call flatten_parameters().", stacklevel=5)

    def forward_extended(self, input, weight, hx):
        assert cudnn.is_acceptable(input)
        # TODO: raise a warning if weight_data_ptr is None

        output = input.new()

        if torch.is_tensor(hx):
            hy = hx.new()
        else:
            hy = tuple(h.new() for h in hx)

        cudnn.rnn.forward(self, input, hx, weight, output, hy)

        self.save_for_backward(input, hx, weight, output)
        return output, hy

    def backward_extended(self, grad_output, grad_hy):
        input, hx, weight, output = self.saved_tensors
        input = input.contiguous()

        grad_input, grad_weight, grad_hx = None, None, None

        assert cudnn.is_acceptable(input)

        grad_input = input.new()
        if torch.is_tensor(hx):
            grad_hx = input.new()
        else:
            grad_hx = tuple(h.new() for h in hx)

        if self.retain_variables:
            self._reserve_clone = self.reserve.clone()

        cudnn.rnn.backward_grad(
            self,
            input,
            hx,
            weight,
            output,
            grad_output,
            grad_hy,
            grad_input,
            grad_hx)

        if any(self.needs_input_grad[1:]):
            grad_weight = [tuple(w.new().resize_as_(w) for w in layer_weight) for layer_weight in weight]
            cudnn.rnn.backward_weight(
                self,
                input,
                hx,
                output,
                weight,
                grad_weight)
        else:
            grad_weight = [(None,) * len(layer_weight) for layer_weight in weight]

        if self.retain_variables:
            self.reserve = self._reserve_clone
            del self._reserve_clone

        return grad_input, grad_weight, grad_hx


def RNN(*args, **kwargs):
    def forward(input, *fargs, **fkwargs):
        if cudnn.is_acceptable(input.data):
            func = CudnnRNN(*args, **kwargs)
        else:
            func = AutogradRNN(*args, **kwargs)
        return func(input, *fargs, **fkwargs)

    return forward
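
# Dispatch sketch: RNN(...) returns a forward callable that uses the cuDNN
# path when cudnn.is_acceptable(input.data) holds and falls back to the pure
# autograd implementation otherwise. Argument values below are illustrative,
# and `all_weights` stands for the per-layer weight list:
#     func = RNN('LSTM', input_size, hidden_size, num_layers=2,
#                batch_first=False, dropout=0, train=True, bidirectional=False)
#     output, (hy, cy) = func(input, all_weights, (hx, cx))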