Allow Keras custom layers to declare the `training` argument of `call` as keyword-only.
PiperOrigin-RevId: 360224577
Change-Id: I5dc803ebfd51cc66b820469b0a1f8860ba3270a6
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index f5474d1..f4ac960 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -2978,12 +2978,15 @@
self.__class__._call_accepts_kwargs.fget.cache.pop(self, None)
call_fn_args = self._call_fn_args
+ call_fn_args += self._call_full_argspec.kwonlyargs or []
self._expects_training_arg = ('training' in call_fn_args or
self._call_accepts_kwargs)
# The default training arg will be any (non-None) default specified in the
# method signature, or None if no value is specified.
- self._default_training_arg = self._call_fn_arg_defaults.get(
- 'training')
+ call_fn_arg_defaults = self._call_fn_arg_defaults.copy()
+ call_fn_arg_defaults.update(self._call_full_argspec.kwonlydefaults or {})
+ self._default_training_arg = call_fn_arg_defaults.get('training')
+
self._expects_mask_arg = ('mask' in call_fn_args or
self._call_accepts_kwargs)
diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py
index 4da6421..d5fec07 100644
--- a/tensorflow/python/keras/engine/base_layer_test.py
+++ b/tensorflow/python/keras/engine/base_layer_test.py
@@ -757,6 +757,88 @@
else:
return self._nested_layer(inputs) * 0.5
+ self._test_custom_layer_training_arg(
+ CustomLayerNoTrainingArg=CustomLayerNoTrainingArg,
+ CustomLayerDefaultTrainingMissing=CustomLayerDefaultTrainingMissing,
+ CustomLayerDefaultTrainingNone=CustomLayerDefaultTrainingNone,
+ CustomLayerDefaultTrainingFalse=CustomLayerDefaultTrainingFalse,
+ CustomLayerDefaultTrainingTrue=CustomLayerDefaultTrainingTrue)
+
+ @combinations.generate(combinations.combine(mode=['eager']))
+ def test_custom_layer_training_arg_kwargonly(self):  # keyword-only variants
+ class CustomLayerNoTrainingArg(base_layer.Layer):  # no `training` param
+
+ def __init__(self, nested_layer=None):
+ super(CustomLayerNoTrainingArg, self).__init__()
+ self._nested_layer = nested_layer or array_ops.identity
+
+ def call(self, inputs):
+ return self._nested_layer(inputs)
+
+ class CustomLayerDefaultTrainingMissing(base_layer.Layer):  # no default
+
+ def __init__(self, nested_layer=None):
+ super(CustomLayerDefaultTrainingMissing, self).__init__()
+ self._nested_layer = nested_layer or array_ops.identity
+
+ def call(self, inputs, *, training):  # keyword-only, required
+ if training:
+ return self._nested_layer(inputs)
+ else:
+ return self._nested_layer(inputs) * 0.5
+
+ class CustomLayerDefaultTrainingNone(base_layer.Layer):  # default None
+
+ def __init__(self, nested_layer=None):
+ super(CustomLayerDefaultTrainingNone, self).__init__()
+ self._nested_layer = nested_layer or array_ops.identity
+
+ def call(self, inputs, *, training=None):  # keyword-only
+ if training:
+ return self._nested_layer(inputs)
+ else:
+ return self._nested_layer(inputs) * 0.5
+
+ class CustomLayerDefaultTrainingFalse(base_layer.Layer):  # default False
+
+ def __init__(self, nested_layer=None):
+ super(CustomLayerDefaultTrainingFalse, self).__init__()
+ self._nested_layer = nested_layer or array_ops.identity
+
+ def call(self, inputs, *, training=False):  # keyword-only
+ if training:
+ return self._nested_layer(inputs)
+ else:
+ return self._nested_layer(inputs) * 0.5
+
+ class CustomLayerDefaultTrainingTrue(base_layer.Layer):  # default True
+
+ def __init__(self, nested_layer=None):
+ super(CustomLayerDefaultTrainingTrue, self).__init__()
+ self._nested_layer = nested_layer or array_ops.identity
+
+ def call(self, inputs, *, training=True):  # keyword-only
+ if training:
+ return self._nested_layer(inputs)
+ else:
+ return self._nested_layer(inputs) * 0.5
+
+ self._test_custom_layer_training_arg(  # helper also used by the test above
+ CustomLayerNoTrainingArg=CustomLayerNoTrainingArg,
+ CustomLayerDefaultTrainingMissing=CustomLayerDefaultTrainingMissing,
+ CustomLayerDefaultTrainingNone=CustomLayerDefaultTrainingNone,
+ CustomLayerDefaultTrainingFalse=CustomLayerDefaultTrainingFalse,
+ CustomLayerDefaultTrainingTrue=CustomLayerDefaultTrainingTrue)
+
+ def _test_custom_layer_training_arg(self,
+ # pylint: disable=invalid-name
+ CustomLayerNoTrainingArg,
+ CustomLayerDefaultTrainingMissing,
+ CustomLayerDefaultTrainingNone,
+ CustomLayerDefaultTrainingFalse,
+ CustomLayerDefaultTrainingTrue,
+ # pylint: enable=invalid-name
+ ):
x = array_ops.ones(shape=(1, 1))
# If the layer signature doesn't specify a default training arg,