Updated docs for custom_gradient with examples
diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py
index 2437e05..5a6b3cc 100644
--- a/tensorflow/python/ops/custom_gradient.py
+++ b/tensorflow/python/ops/custom_gradient.py
@@ -125,6 +125,45 @@
With this definition, the gradient at x=100 will be correctly evaluated as
1.0.
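+
+ For instance, repeating the `log1pexp` definition above, the corrected
+ gradient can be checked directly (a sketch; `tf.GradientTape` invokes the
+ custom `grad` function during backpropagation):
+
+ ```python
+ @tf.custom_gradient
+ def log1pexp(x):
+   e = tf.exp(x)
+   def grad(dy):
+     return dy * (1 - 1 / (1 + e))
+   return tf.math.log(1 + e), grad
+
+ x = tf.constant(100., dtype=tf.float32)
+ with tf.GradientTape() as tape:
+   tape.watch(x)
+   y = log1pexp(x)
+ g = tape.gradient(y, x)
+ tf.print(g)  # Output: 1
+ ```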
+ The variable `dy` is the upstream gradient, i.e. the gradient flowing back
+ from all the layers or functions that consume this layer's output.
+
+ By the chain rule we know that
+ `dy/dx = dy/dx_0 * dx_0/dx_1 * ... * dx_i/dx_i+1 * ... * dx_n/dx`
+
+ In this case the gradient of the current function is `dx_i/dx_i+1 = (1 - 1 / (1 + e))`.
+ The upstream gradient `dy` is `dy/dx_0 * dx_0/dx_1 * ... * dx_i-1/dx_i`.
+ The upstream gradient multiplied by the current gradient is then passed downstream.
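+
+ As an illustrative sketch (the names here are hypothetical), a custom `square`
+ function receives the upstream gradient from whatever is computed after it and
+ multiplies it by its local gradient `2 * x`:
+
+ ```python
+ @tf.custom_gradient
+ def square(x):
+   def grad(upstream):
+     # Local gradient d(x^2)/dx = 2x, scaled by the upstream gradient.
+     return upstream * 2.0 * x
+   return x * x, grad
+
+ x = tf.constant(3.0)
+ with tf.GradientTape() as tape:
+   tape.watch(x)
+   # `* 5.0` happens after `square`, so `grad` receives upstream = 5.0.
+   y = square(x) * 5.0
+ g = tape.gradient(y, x)
+ tf.print(g)  # Output: 30
+ ```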
+
+ If the function takes multiple variables as input, the `grad` function must also
+ return the same number of gradients, one per input. We take `z = x * y` as an example.
+
+ ```python
+ @tf.custom_gradient
+ def bar(x, y):
+ def grad(upstream):
+ dz_dx = y
+ dz_dy = x
+ return upstream * dz_dx, upstream * dz_dy
+
+ z = x * y
+
+ return z, grad
+
+ x = tf.constant(2.0, dtype=tf.float32)
+ y = tf.constant(3.0, dtype=tf.float32)
+
+ with tf.GradientTape(persistent=True) as tape:
+ tape.watch(x)
+ tape.watch(y)
+ z = bar(x, y)
+
+ tf.print(z) # Output: 6
+ tf.print(tape.gradient(z, x)) # Output: 3
+ tf.print(tape.gradient(z, y)) # Output: 2
+ tf.print(tape.gradient(x, y)) # Output: None
+ ```
+
Nesting custom gradients can lead to unintuitive results. The default
behavior does not correspond to n-th order derivatives. For example