Standardize name scopes used during model construction.
PiperOrigin-RevId: 301713794
Change-Id: Ifa309e22955183968ad51c6989be5356b8266cc1
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 8ae529f..66de8e7 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -2083,7 +2083,16 @@
self._dtype_policy = policy.Policy(value)
def _name_scope(self):
- return self.name
+ name_scope = self.name
+ current_name_scope = ops.get_name_scope()
+ if current_name_scope:
+ name_scope = current_name_scope + '/' + name_scope
+ if name_scope:
+ # Note that the trailing `/` prevents autogenerated
+ # numerical suffixes to get appended. It will also fully reset
+ # nested name scope (i.e. the outer name scope has no effect).
+ name_scope += '/'
+ return name_scope
def _init_set_name(self, name, zero_based=True):
if not name:
diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py
index 94766fe..1999f31 100644
--- a/tensorflow/python/keras/engine/base_layer_test.py
+++ b/tensorflow/python/keras/engine/base_layer_test.py
@@ -936,6 +936,30 @@
self.assertEqual(layer.bias.name, 'MyName/bias:0')
self.assertEqual(layer.kernel.name, 'MyName/kernel:0')
+ def test_name_scope_functional_api(self):
+ inputs = input_layer.Input((3,))
+ layer = layers.Dense(10, name='MyName')
+ _ = layer(inputs)
+ self.assertEqual(layer.bias.name, 'MyName/bias:0')
+ self.assertEqual(layer.kernel.name, 'MyName/kernel:0')
+
+ def test_name_scope_functional_api_nested(self):
+
+ class NestedLayer(base_layer.Layer):
+
+ def __init__(self, name='OuterName'):
+ super(NestedLayer, self).__init__(name=name)
+ self.dense = layers.Dense(10, name='InnerName')
+
+ def call(self, inputs):
+ return self.dense(inputs)
+
+ inputs = input_layer.Input((3,))
+ layer = NestedLayer()
+ _ = layer(inputs)
+ self.assertEqual(layer.dense.bias.name, 'OuterName/InnerName/bias:0')
+ self.assertEqual(layer.dense.kernel.name, 'OuterName/InnerName/kernel:0')
+
def test_name_scope_sublayer(self):
class NameScopeTracker(base_layer.Layer):