Fix `layers.Average`, `layers.average` docs.
PiperOrigin-RevId: 290384498
Change-Id: Ia87ef9511f84852fa4515f8647ca01c97f31aa7f
diff --git a/tensorflow/python/keras/layers/merge.py b/tensorflow/python/keras/layers/merge.py
index be1e1a9..1deae97 100644
--- a/tensorflow/python/keras/layers/merge.py
+++ b/tensorflow/python/keras/layers/merge.py
@@ -305,9 +305,30 @@
class Average(_Merge):
"""Layer that averages a list of inputs.
- It takes as input a list of tensors,
- all of the same shape, and returns
+ It takes as input a list of tensors, all of the same shape, and returns
a single tensor (also of the same shape).
+
+ Example:
+
+ >>> x1 = np.ones((2, 2))
+ >>> x2 = np.zeros((2, 2))
+ >>> y = tf.keras.layers.Average()([x1, x2])
+ >>> y.numpy().tolist()
+ [[0.5, 0.5], [0.5, 0.5]]
+
+ Usage in a functional model:
+
+ >>> input1 = tf.keras.layers.Input(shape=(16,))
+ >>> x1 = tf.keras.layers.Dense(8, activation='relu')(input1)
+ >>> input2 = tf.keras.layers.Input(shape=(32,))
+ >>> x2 = tf.keras.layers.Dense(8, activation='relu')(input2)
+ >>> avg = tf.keras.layers.Average()([x1, x2])
+ >>> out = tf.keras.layers.Dense(4)(avg)
+ >>> model = tf.keras.models.Model(inputs=[input1, input2], outputs=out)
+
+ Raises:
+ ValueError: If there is a shape mismatch between the inputs and the shapes
+ cannot be broadcasted to match.
"""
def _merge_function(self, inputs):
@@ -645,12 +666,34 @@
def average(inputs, **kwargs):
"""Functional interface to the `tf.keras.layers.Average` layer.
+ Example:
+
+ >>> x1 = np.ones((2, 2))
+ >>> x2 = np.zeros((2, 2))
+ >>> y = tf.keras.layers.Average()([x1, x2])
+ >>> y.numpy().tolist()
+ [[0.5, 0.5], [0.5, 0.5]]
+
+ Usage in a functional model:
+
+ >>> input1 = tf.keras.layers.Input(shape=(16,))
+ >>> x1 = tf.keras.layers.Dense(8, activation='relu')(input1)
+ >>> input2 = tf.keras.layers.Input(shape=(32,))
+ >>> x2 = tf.keras.layers.Dense(8, activation='relu')(input2)
+ >>> avg = tf.keras.layers.Average()([x1, x2])
+ >>> out = tf.keras.layers.Dense(4)(avg)
+ >>> model = tf.keras.models.Model(inputs=[input1, input2], outputs=out)
+
Arguments:
inputs: A list of input tensors (at least 2).
**kwargs: Standard layer keyword arguments.
Returns:
A tensor, the average of the inputs.
+
+ Raises:
+ ValueError: If there is a shape mismatch between the inputs and the shapes
+ cannot be broadcasted to match.
"""
return Average(**kwargs)(inputs)