Add proper error message when user does not call "super.__init__()" in custom layer.
The current error message is quite obscure and surprising.
PiperOrigin-RevId: 290557761
Change-Id: Ibf91099bceb29b562e100e399aaf4757c74060d6
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 2f04b4a..71d3084 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -650,7 +650,11 @@
Raises:
ValueError: if the layer's `call` method returns None (an invalid value).
+ RuntimeError: if `super().__init__()` was not called in the constructor.
"""
+ if not hasattr(self, '_thread_local'):
+ raise RuntimeError(
+ 'You must call `super().__init__()` in the layer constructor.')
call_context = base_layer_utils.call_context()
input_list = nest.flatten(inputs)
diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py
index fa77088..98ec9d8 100644
--- a/tensorflow/python/keras/engine/base_layer_test.py
+++ b/tensorflow/python/keras/engine/base_layer_test.py
@@ -582,6 +582,17 @@
model = keras.Sequential(dense)
self.assertEqual(model.count_params(), 16 * 4 + 16)
+ def test_super_not_called(self):
+
+ class CustomLayerNotCallingSuper(keras.layers.Layer):
+
+ def __init__(self):
+ pass
+
+ layer = CustomLayerNotCallingSuper()
+ with self.assertRaisesRegexp(RuntimeError, 'You must call `super()'):
+ layer(np.random.random((10, 2)))
+
class SymbolicSupportTest(test.TestCase):