RNN: avoid copying gradients of inputs to the RNN cell and save more memory!
Summary:
This is a bit tricky to explain, but we can simply use
backward_links: the whole cell then uses a blob from the
states_grad tensor instead of having its own blob. This should
also save a bit of memory.
Differential Revision: D4770798
fbshipit-source-id: 673f85b2c2fdf42c47feeaa24d1e2bf086f012f9
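
Below is a minimal sketch of the idea in plain numpy rather than the Caffe2 API
(the shapes and the states_grad layout are illustrative assumptions): a
backward_link hands the cell a view into the shared gradient tensor, so the
explicit Copy disappears.

import numpy as np

T, N, D = 4, 2, 3                       # timesteps, batch size, hidden size
states_grad = np.zeros((T + 1, N, D))   # shared gradient workspace

# Old behavior: the cell owns a separate blob and we Copy it over.
cell_input_grad = np.random.randn(N, D)
states_grad[0] = cell_input_grad        # extra blob plus an explicit copy

# New behavior: the backward_link hands the cell an alias into states_grad,
# so whatever the cell writes lands in the shared tensor with no copy.
aliased_grad = states_grad[0]           # numpy view, shares memory
aliased_grad += np.random.randn(N, D)   # the cell writes in place
assert np.shares_memory(aliased_grad, states_grad)
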
diff --git a/caffe2/python/operator_test/recurrent_network_test.py b/caffe2/python/operator_test/recurrent_network_test.py
index b56391f..cebccf0 100644
--- a/caffe2/python/operator_test/recurrent_network_test.py
+++ b/caffe2/python/operator_test/recurrent_network_test.py
@@ -375,9 +375,6 @@
)
inputs = [workspace.FetchBlob(name) for name in op.input]
- print(op.input)
- print(inputs)
-
self.assertReferenceChecks(
hu.cpu_do,
op,
diff --git a/caffe2/python/recurrent.py b/caffe2/python/recurrent.py
index 664cb1a..4d79ac0 100644
--- a/caffe2/python/recurrent.py
+++ b/caffe2/python/recurrent.py
@@ -136,7 +136,6 @@
cell_output = links[str(cell_input)]
forward_links.append((cell_input, state, 0))
forward_links.append((cell_output, state, 1))
- backward_links.append((cell_input + "_grad", states_grad, 0))
backward_links.append((cell_output + "_grad", states_grad, 1))
backward_cell_net.Proto().external_input.append(
@@ -150,16 +149,28 @@
recurrent_input_grad = cell_input + "_grad"
if not backward_blob_versions.get(recurrent_input_grad, 0):
# If nobody writes to this recurrent input gradient, we need
- # to perform a munual copy. This is a case if SumOp is being
- # used as first operator of the cell net
- backward_cell_net.Copy(
- backward_mapping[cell_input], recurrent_input_grad)
- # Similarly, we need to copy over gradient values for the parameters that
- # are added as ExternalInputs (excluding timestep) to the step net
+ # to make sure it still ends up in the states_grad blob.
+ # We do this via backward_links, which creates an alias.
+ # This path is taken, e.g., when the cell net starts with a SumOp.
+ backward_links.append(
+ (backward_mapping[cell_input], states_grad, 0))
+ else:
+ backward_links.append((cell_input + "_grad", states_grad, 0))
+
for reference in references:
+ # Similarly to the above, in the SumOp case we need to write the parameter
+ # gradient to an external blob. Here we can be sure that
+ # reference + "_grad" is a correct parameter name, since we know what
+ # the RecurrentNetworkOp gradient schema looks like.
reference_grad = reference + "_grad"
if (reference in backward_mapping and
reference_grad != str(backward_mapping[reference])):
+ # We can use an Alias because after each timestep the
+ # RNN op adds the value from reference_grad into an _acc blob,
+ # which accumulates the gradients for the corresponding parameter across
+ # timesteps. At the end of the RNN op the two blobs are
+ # swapped, so reference_grad becomes a real blob instead of
+ # being an alias.
backward_cell_net.Alias(
backward_mapping[reference], reference_grad)
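
A similarly minimal sketch, again in plain numpy rather than the Caffe2 API
(reference_grad and the _acc accumulator are modeled as plain arrays), of why
aliasing the parameter gradient is safe: the op accumulates each timestep's
contribution into the _acc blob and only swaps it in for reference_grad at the
end.

import numpy as np

T, D = 3, 4
reference_grad = np.zeros(D)       # per-timestep gradient the cell writes to
reference_grad_acc = np.zeros(D)   # accumulator across timesteps ("_acc" blob)

for t in range(T):
    per_step = np.full(D, float(t + 1))   # stand-in for the cell's gradient
    reference_grad[:] = per_step          # written through the alias
    reference_grad_acc += reference_grad  # RNN op accumulates after each step

# At the end the accumulated blob is swapped in for reference_grad,
# so the per-timestep alias never loses information.
reference_grad = reference_grad_acc
print(reference_grad)                     # [6. 6. 6. 6.]
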