Add checking of kwonlyargs in tensorflow.python.keras.utils.generic_utils.has_arg
diff --git a/tensorflow/python/keras/utils/generic_utils.py b/tensorflow/python/keras/utils/generic_utils.py
index 27cab6b..f26e6a6 100644
--- a/tensorflow/python/keras/utils/generic_utils.py
+++ b/tensorflow/python/keras/utils/generic_utils.py
@@ -476,7 +476,7 @@
arg_spec = tf_inspect.getfullargspec(fn)
if accept_all and arg_spec.varkw is not None:
return True
- return name in arg_spec.args
+ return name in arg_spec.args or name in arg_spec.kwonlyargs
@keras_export('keras.utils.Progbar')
diff --git a/tensorflow/python/keras/utils/generic_utils_test.py b/tensorflow/python/keras/utils/generic_utils_test.py
index 1f0c1e2..302a1b4 100644
--- a/tensorflow/python/keras/utils/generic_utils_test.py
+++ b/tensorflow/python/keras/utils/generic_utils_test.py
@@ -18,6 +18,8 @@
from __future__ import division
from __future__ import print_function
+from functools import partial
+
import numpy as np
from tensorflow.python import keras
@@ -38,6 +40,11 @@
def f_x_kwargs(x, **kwargs):
_ = kwargs
return x
+
+ def f(a, b, c):
+ return a + b + c
+
+ partial_f = partial(f, b=1)
self.assertTrue(keras.utils.generic_utils.has_arg(
f_x, 'x', accept_all=False))
@@ -53,6 +60,8 @@
f_x_kwargs, 'y', accept_all=False))
self.assertTrue(keras.utils.generic_utils.has_arg(
f_x_kwargs, 'y', accept_all=True))
+ self.assertTrue(keras.utils.generic_utils.has_arg(
+ partial_f, 'c', accept_all=True))
class TestCustomObjectScope(test.TestCase):