Use torch.zeros for nn.LSTM
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/16779
Differential Revision: D13963577
Pulled By: driazati
fbshipit-source-id: dc9edc3d2096760737ecbe4b3dd441ed2d53f4ad
diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py
index 771df23..8e90363 100644
--- a/torch/nn/modules/rnn.py
+++ b/torch/nn/modules/rnn.py
@@ -183,9 +183,9 @@
if hx is None:
num_directions = 2 if self.bidirectional else 1
- hx = input.new_zeros(self.num_layers * num_directions,
- max_batch_size, self.hidden_size,
- requires_grad=False)
+ hx = torch.zeros(self.num_layers * num_directions,
+ max_batch_size, self.hidden_size,
+ dtype=input.dtype, device=input.device)
if self.mode == 'LSTM':
hx = (hx, hx)
else: