Use torch.zeros for nn.LSTM

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/16779

Differential Revision: D13963577

Pulled By: driazati

fbshipit-source-id: dc9edc3d2096760737ecbe4b3dd441ed2d53f4ad
diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py
index 771df23..8e90363 100644
--- a/torch/nn/modules/rnn.py
+++ b/torch/nn/modules/rnn.py
@@ -183,9 +183,9 @@
 
         if hx is None:
             num_directions = 2 if self.bidirectional else 1
-            hx = input.new_zeros(self.num_layers * num_directions,
-                                 max_batch_size, self.hidden_size,
-                                 requires_grad=False)
+            hx = torch.zeros(self.num_layers * num_directions,
+                             max_batch_size, self.hidden_size,
+                             dtype=input.dtype, device=input.device)
             if self.mode == 'LSTM':
                 hx = (hx, hx)
         else: