import torch
from torch.autograd import Function, NestedIOFunction, Variable
import torch.backends.cudnn as cudnn
from .. import functional as F
try:
    import torch.backends.cudnn.rnn
except ImportError:
    pass


def RNNReLUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
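    """One step of a vanilla RNN cell with ReLU nonlinearity.

    Computes hy = relu(W_ih x + b_ih + W_hh h + b_hh).
    """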
    hy = F.relu(F.linear(input, w_ih, b_ih) + F.linear(hidden, w_hh, b_hh))
    return hy


def RNNTanhCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
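    """One step of a vanilla RNN cell with tanh nonlinearity.

    Computes hy = tanh(W_ih x + b_ih + W_hh h + b_hh).
    """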
    hy = F.tanh(F.linear(input, w_ih, b_ih) + F.linear(hidden, w_hh, b_hh))
    return hy


def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
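    """One step of an LSTM cell; `hidden` is a (h, c) tuple.

    The fused input and hidden projections are split into the input, forget,
    cell and output gates (in that order, matching the 4 * hidden_size weight
    layout):

        i = sigmoid(W_ii x + b_ii + W_hi h + b_hi)
        f = sigmoid(W_if x + b_if + W_hf h + b_hf)
        g = tanh(W_ig x + b_ig + W_hg h + b_hg)
        o = sigmoid(W_io x + b_io + W_ho h + b_ho)
        c' = f * c + i * g
        h' = o * tanh(c')

    Returns (h', c').
    """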
    hx, cx = hidden
    gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate = F.sigmoid(ingate)
    forgetgate = F.sigmoid(forgetgate)
    cellgate = F.tanh(cellgate)
    outgate = F.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * F.tanh(cy)

    return hy, cy


def GRUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
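    """One step of a GRU cell.

    The fused projections are split into reset, update and new gates
    (3 * hidden_size weight layout):

        r = sigmoid(W_ir x + b_ir + W_hr h + b_hr)
        z = sigmoid(W_iz x + b_iz + W_hz h + b_hz)
        n = tanh(W_in x + b_in + r * (W_hn h + b_hn))
        h' = n + z * (h - n)    # equivalently (1 - z) * n + z * h

    Note that the reset gate multiplies the already-projected hidden state,
    which is the cuDNN-compatible formulation of the GRU.
    """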
    gi = F.linear(input, w_ih, b_ih)
    gh = F.linear(hidden, w_hh, b_hh)
    i_r, i_i, i_n = gi.chunk(3, 1)
    h_r, h_i, h_n = gh.chunk(3, 1)

    resetgate = F.sigmoid(i_r + h_r)
    inputgate = F.sigmoid(i_i + h_i)
    newgate = F.tanh(i_n + resetgate * h_n)
    hy = newgate + inputgate * (hidden - newgate)

    return hy


def StackedRNN(inners, num_layers, lstm=False, dropout=0, train=True):
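    """Compose per-direction step functions into a multi-layer (optionally
    bidirectional) RNN.

    `inners` holds one Recurrent() closure per direction. The `weight`
    argument of the returned forward() must contain one weight tuple per
    (layer, direction) pair, indexed as layer * num_directions + direction.
    Dropout is applied to the output of every layer except the last.
    """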

    num_directions = len(inners)
    total_layers = num_layers * num_directions

    def forward(input, hidden, weight):
        assert(len(weight) == total_layers)
        next_hidden = []

        if lstm:
            # for LSTM the hidden state is a (h_0, c_0) pair of stacked tensors;
            # regroup it into one (h, c) tuple per layer/direction
            hidden = list(zip(*hidden))

        for i in range(num_layers):
            all_output = []
            for j, inner in enumerate(inners):
                l = i * num_directions + j

                hy, output = inner(input, hidden[l], weight[l])
                next_hidden.append(hy)
                all_output.append(output)

            # concatenate the per-direction outputs along the feature dimension
            input = torch.cat(all_output, 2)

            if dropout != 0 and i < num_layers - 1:
                input = F.dropout(input, p=dropout, training=train, inplace=False)

        if lstm:
            next_h, next_c = zip(*next_hidden)
            next_hidden = (
                torch.cat(next_h, 0).view(total_layers, *next_h[0].size()),
                torch.cat(next_c, 0).view(total_layers, *next_c[0].size())
            )
        else:
            next_hidden = torch.cat(next_hidden, 0).view(
                total_layers, *next_hidden[0].size())

        return next_hidden, input

    return forward


def Recurrent(inner, reverse=False):
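    """Unroll a single cell over the time (first) dimension of the input.

    The returned forward() feeds input[t] and the running hidden state to
    `inner` at every step, optionally iterating in reverse, and returns the
    final hidden state together with the per-step outputs stacked along
    dimension 0.
    """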
    def forward(input, hidden, weight):
        output = []
        steps = range(input.size(0) - 1, -1, -1) if reverse else range(input.size(0))
        for i in steps:
            hidden = inner(input[i], hidden, *weight)
            # LSTM cells return (hy, cy); only hy contributes to the output sequence
            output.append(hidden[0] if isinstance(hidden, tuple) else hidden)

        if reverse:
            output.reverse()
        output = torch.cat(output, 0).view(input.size(0), *output[0].size())

        return hidden, output

    return forward


def AutogradRNN(mode, input_size, hidden_size, num_layers=1, batch_first=False,
                dropout=0, train=True, bidirectional=False, dropout_state=None):
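    """Pure-autograd RNN implementation, built from the cell functions above.

    Selects the requested cell (RNN_RELU, RNN_TANH, LSTM or GRU), wraps it in
    Recurrent() for each direction and stacks the layers with StackedRNN().
    The returned forward(input, weight, hidden) expects a seq-first input
    unless batch_first is set, in which case it transposes on the way in and
    out.
    """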

    if mode == 'RNN_RELU':
        cell = RNNReLUCell
    elif mode == 'RNN_TANH':
        cell = RNNTanhCell
    elif mode == 'LSTM':
        cell = LSTMCell
    elif mode == 'GRU':
        cell = GRUCell
    else:
        raise Exception('Unknown mode: {}'.format(mode))

    if bidirectional:
        layer = (Recurrent(cell), Recurrent(cell, reverse=True))
    else:
        layer = (Recurrent(cell),)

    func = StackedRNN(layer,
                      num_layers,
                      (mode == 'LSTM'),
                      dropout=dropout,
                      train=train)

    def forward(input, weight, hidden):
        if batch_first:
            input = input.transpose(0, 1)

        nexth, output = func(input, hidden, weight)

        if batch_first:
            output = output.transpose(0, 1)

        return output, nexth

    return forward


class CudnnRNN(NestedIOFunction):
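    """Autograd function that runs the whole multi-layer RNN through the
    fused cuDNN kernels instead of unrolling it step by step in Python."""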

    def __init__(self, mode, input_size, hidden_size, num_layers=1,
                 batch_first=False, dropout=0, train=True, bidirectional=False,
                 dropout_state=None):
        super(CudnnRNN, self).__init__()
        if dropout_state is None:
            dropout_state = {}
        self.mode = cudnn.rnn.get_cudnn_mode(mode)
        self.input_mode = cudnn.CUDNN_LINEAR_INPUT
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.dropout = dropout
        self.train = train
        self.bidirectional = 1 if bidirectional else 0
        self.num_directions = 2 if bidirectional else 1
        self.dropout_seed = torch.IntTensor(1).random_()[0]
        self.dropout_state = dropout_state

    def forward_extended(self, input, weight, hx):
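        """Run the cuDNN forward kernel and save the tensors needed for backward."""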

        assert(cudnn.is_acceptable(input))

        output = input.new()

        if torch.is_tensor(hx):
            hy = hx.new()
        else:
            hy = tuple(h.new() for h in hx)

        cudnn.rnn.forward(self, input, hx, weight, output, hy)

        self.save_for_backward(input, hx, weight, output)
        return output, hy

    def backward_extended(self, grad_output, grad_hy):
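        """Compute gradients w.r.t. the input, hidden state and, if required,
        the weights using the cuDNN backward kernels."""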
        input, hx, weight, output = self.saved_tensors

        grad_input, grad_weight, grad_hx = None, None, None

        assert cudnn.is_acceptable(input)

        grad_input = input.new()
        if torch.is_tensor(hx):
            grad_hx = input.new()
        else:
            grad_hx = tuple(h.new() for h in hx)

        if self.retain_variables:
            # backward_grad may modify the cuDNN reserve buffer in place; keep a
            # copy so a retained graph can be backpropagated through again
            self._reserve_clone = self.reserve.clone()

        cudnn.rnn.backward_grad(
            self,
            input,
            hx,
            weight,
            output,
            grad_output,
            grad_hy,
            grad_input,
            grad_hx)

        if any(self.needs_input_grad[1:]):
            grad_weight = [tuple(w.new().resize_as_(w) for w in layer_weight)
                           for layer_weight in weight]
            cudnn.rnn.backward_weight(
                self,
                input,
                hx,
                output,
                weight,
                grad_weight)

        if self.retain_variables:
            self.reserve = self._reserve_clone
            del self._reserve_clone

        return grad_input, grad_weight, grad_hx


def RNN(*args, **kwargs):
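    """Return a forward callable that uses CudnnRNN when cuDNN accepts the
    input and falls back to AutogradRNN otherwise.

    Rough usage sketch (names are illustrative; the weight list is assumed to
    hold one (w_ih, w_hh, b_ih, b_hh) sequence per layer and direction, as
    StackedRNN expects):

        func = RNN('LSTM', input_size, hidden_size, num_layers=1)
        output, (hy, cy) = func(input, all_weights, (h0, c0))
    """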
    def forward(input, *fargs, **fkwargs):
        if cudnn.is_acceptable(input.data):
            func = CudnnRNN(*args, **kwargs)
        else:
            func = AutogradRNN(*args, **kwargs)
        return func(input, *fargs, **fkwargs)

    return forward