Cast dropout to float in RNN (#21304)
Summary:
This solves the situation where, for example, someone instantiates LSTM with `dropout=0`, a Python integer. This works fine in Python, but JIT throws a type error because it expected float but got int
Resolves https://github.com/pytorch/lockdown/issues/65
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21304
Differential Revision: D15613153
Pulled By: jamesr66a
fbshipit-source-id: eabff76e3af3de0612583b37dbc5f7eab7e248a4
diff --git a/test/test_jit.py b/test/test_jit.py
index f226ef7..b92e87e 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -12684,7 +12684,7 @@
class M(torch.jit.ScriptModule):
def __init__(self):
super(M, self).__init__()
- self.rnn = nn.LSTM(2, 3, 2)
+ self.rnn = nn.LSTM(2, 3, 2, dropout=0)
@torch.jit.script_method
def forward(self, x, lengths, h0, c0):
@@ -12693,7 +12693,7 @@
class Eager(torch.nn.Module):
def __init__(self):
super(Eager, self).__init__()
- self.rnn = nn.LSTM(2, 3, 2)
+ self.rnn = nn.LSTM(2, 3, 2, dropout=0)
def forward(self, x, lengths, h0, c0):
return self.rnn(x, (h0, c0))[0]
diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py
index 5d705e2..f2c6f70 100644
--- a/torch/nn/modules/rnn.py
+++ b/torch/nn/modules/rnn.py
@@ -38,7 +38,7 @@
self.num_layers = num_layers
self.bias = bias
self.batch_first = batch_first
- self.dropout = dropout
+ self.dropout = float(dropout)
self.bidirectional = bidirectional
num_directions = 2 if bidirectional else 1