Fix test_rnn_args_check (#8606)
test_rnn_args_check generates mismatched input_shape and hidden_shape
args. To do this, it changes a dimension of input_shape or hidden_shape
to have an incorrect size.
Before, the test was changing the size of a dimension to -1. However,
this is flawed because an input of size i.e. (6, -1, 2) is wrong.
This PR fixes it so that the test changes sizes of dimensions to
`bad_size = 7`. As long as none of the other sizes (input_size,
hidden_size, num_layers, batch_size) divide this, we don't have to worry
about that dimension being accidentally broadcasted into working.
diff --git a/test/test_nn.py b/test/test_nn.py
index 69f094c..dbe2c72 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -3682,6 +3682,7 @@
batch_size = 4
seq_len = 6
num_directions = 1
+ bad_size = 7 # prime number so that no size can divide it.
def test(input_shape, hidden_shape, mode):
for input, hidden in get_inputs(input_shape, hidden_shape, mode):
@@ -3691,10 +3692,10 @@
correct_input_shape = (seq_len, batch_size, input_size)
correct_hidden_shape = (num_layers * num_directions, batch_size, hidden_size)
- def update_tuple(tup, dim, delta):
- new_tup = list(tup)
- new_tup[dim] = delta
- return tuple(new_tup)
+ def update_shape(shape, dim, new_dim_size):
+ new_shape = list(shape)
+ new_shape[dim] = new_dim_size
+ return tuple(new_shape)
def get_inputs(input_shape, hidden_shape, mode):
'''returns list( tuple(input, hidden) )
@@ -3714,28 +3715,28 @@
rnn_modes = ['RNN', 'GRU', 'LSTM']
for mode in rnn_modes:
# Incorrect input batch size
- input_shape = update_tuple(correct_input_shape, 1, -1)
+ input_shape = update_shape(correct_input_shape, 1, bad_size)
hidden_shape = correct_hidden_shape
test(input_shape, hidden_shape, mode)
# Incorrect hidden batch size
input_shape = correct_input_shape
- hidden_shape = update_tuple(correct_hidden_shape, 1, -1)
+ hidden_shape = update_shape(correct_hidden_shape, 1, bad_size)
test(input_shape, hidden_shape, mode)
# Incorrect input size
- input_shape = update_tuple(correct_input_shape, 2, -1)
+ input_shape = update_shape(correct_input_shape, 2, bad_size)
hidden_shape = correct_hidden_shape
test(input_shape, hidden_shape, mode)
# Incorrect hidden size
input_shape = correct_input_shape
- hidden_shape = update_tuple(correct_hidden_shape, 2, -1)
+ hidden_shape = update_shape(correct_hidden_shape, 2, bad_size)
test(input_shape, hidden_shape, mode)
# Incorrect hidden[0]
input_shape = correct_input_shape
- hidden_shape = update_tuple(correct_hidden_shape, 0, -1)
+ hidden_shape = update_shape(correct_hidden_shape, 0, bad_size)
test(input_shape, hidden_shape, mode)
def test_rnn_initial_hidden_state(self):