Allow keras custom layers to make `training` arg as keyword only.

PiperOrigin-RevId: 360224577
Change-Id: I5dc803ebfd51cc66b820469b0a1f8860ba3270a6
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index f5474d1..f4ac960 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -2978,12 +2978,15 @@
     self.__class__._call_accepts_kwargs.fget.cache.pop(self, None)
 
     call_fn_args = self._call_fn_args
+    call_fn_args += self._call_full_argspec.kwonlyargs or []
     self._expects_training_arg = ('training' in call_fn_args or
                                   self._call_accepts_kwargs)
     # The default training arg will be any (non-None) default specified in the
     # method signature, or None if no value is specified.
-    self._default_training_arg = self._call_fn_arg_defaults.get(
-        'training')
+    call_fn_arg_defaults = self._call_fn_arg_defaults.copy()
+    call_fn_arg_defaults.update(self._call_full_argspec.kwonlydefaults or {})
+    self._default_training_arg = call_fn_arg_defaults.get('training')
+
     self._expects_mask_arg = ('mask' in call_fn_args or
                               self._call_accepts_kwargs)
 
diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py
index 4da6421..d5fec07 100644
--- a/tensorflow/python/keras/engine/base_layer_test.py
+++ b/tensorflow/python/keras/engine/base_layer_test.py
@@ -757,6 +757,88 @@
         else:
           return self._nested_layer(inputs) * 0.5
 
+    self._test_custom_layer_training_arg(
+        CustomLayerNoTrainingArg=CustomLayerNoTrainingArg,
+        CustomLayerDefaultTrainingMissing=CustomLayerDefaultTrainingMissing,
+        CustomLayerDefaultTrainingNone=CustomLayerDefaultTrainingNone,
+        CustomLayerDefaultTrainingFalse=CustomLayerDefaultTrainingFalse,
+        CustomLayerDefaultTrainingTrue=CustomLayerDefaultTrainingTrue)
+
+  @combinations.generate(combinations.combine(mode=['eager']))
+  def test_custom_layer_training_arg_kwargonly(self):
+    class CustomLayerNoTrainingArg(base_layer.Layer):
+
+      def __init__(self, nested_layer=None):
+        super(CustomLayerNoTrainingArg, self).__init__()
+        self._nested_layer = nested_layer or array_ops.identity
+
+      def call(self, inputs):
+        return self._nested_layer(inputs)
+
+    class CustomLayerDefaultTrainingMissing(base_layer.Layer):
+
+      def __init__(self, nested_layer=None):
+        super(CustomLayerDefaultTrainingMissing, self).__init__()
+        self._nested_layer = nested_layer or array_ops.identity
+
+      def call(self, inputs, *, training):
+        if training:
+          return self._nested_layer(inputs)
+        else:
+          return self._nested_layer(inputs) * 0.5
+
+    class CustomLayerDefaultTrainingNone(base_layer.Layer):
+
+      def __init__(self, nested_layer=None):
+        super(CustomLayerDefaultTrainingNone, self).__init__()
+        self._nested_layer = nested_layer or array_ops.identity
+
+      def call(self, inputs, *, training=None):
+        if training:
+          return self._nested_layer(inputs)
+        else:
+          return self._nested_layer(inputs) * 0.5
+
+    class CustomLayerDefaultTrainingFalse(base_layer.Layer):
+
+      def __init__(self, nested_layer=None):
+        super(CustomLayerDefaultTrainingFalse, self).__init__()
+        self._nested_layer = nested_layer or array_ops.identity
+
+      def call(self, inputs, *, training=False):
+        if training:
+          return self._nested_layer(inputs)
+        else:
+          return self._nested_layer(inputs) * 0.5
+
+    class CustomLayerDefaultTrainingTrue(base_layer.Layer):
+
+      def __init__(self, nested_layer=None):
+        super(CustomLayerDefaultTrainingTrue, self).__init__()
+        self._nested_layer = nested_layer or array_ops.identity
+
+      def call(self, inputs, *, training=True):
+        if training:
+          return self._nested_layer(inputs)
+        else:
+          return self._nested_layer(inputs) * 0.5
+
+    self._test_custom_layer_training_arg(
+        CustomLayerNoTrainingArg=CustomLayerNoTrainingArg,
+        CustomLayerDefaultTrainingMissing=CustomLayerDefaultTrainingMissing,
+        CustomLayerDefaultTrainingNone=CustomLayerDefaultTrainingNone,
+        CustomLayerDefaultTrainingFalse=CustomLayerDefaultTrainingFalse,
+        CustomLayerDefaultTrainingTrue=CustomLayerDefaultTrainingTrue)
+
+  def _test_custom_layer_training_arg(self,
+                                      # pylint: disable=invalid-name
+                                      CustomLayerNoTrainingArg,
+                                      CustomLayerDefaultTrainingMissing,
+                                      CustomLayerDefaultTrainingNone,
+                                      CustomLayerDefaultTrainingFalse,
+                                      CustomLayerDefaultTrainingTrue,
+                                      # pylint: enable=invalid-name
+                                      ):
     x = array_ops.ones(shape=(1, 1))
 
     # If the layer signature doesn't specify a default training arg,