removes duplicate variable reference crash from pad_sequences (#4383)
diff --git a/torch/nn/utils/rnn.py b/torch/nn/utils/rnn.py
index 7445c73..4de8550 100644
--- a/torch/nn/utils/rnn.py
+++ b/torch/nn/utils/rnn.py
@@ -176,10 +176,8 @@
prev_l = max_len
if batch_first:
out_dims = (len(sequences), max_len) + trailing_dims
- batch_dim = 0
else:
out_dims = (max_len, len(sequences)) + trailing_dims
- batch_dim = 1
out_variable = Variable(sequences[0].data.new(*out_dims).zero_())
for i, variable in enumerate(sequences):
@@ -188,7 +186,12 @@
if prev_l < length:
raise ValueError("lengths array has to be sorted in decreasing order")
prev_l = length
- out_variable.select(batch_dim, i)[:length] = variable
+ # use index notation to prevent duplicate references to the variable
+ if batch_first:
+ out_variable[i, :length, ...] = variable
+ else:
+ out_variable[:length, i, ...] = variable
+
return out_variable