[INTEL MKL] Added input name in the bfloat16 namescope interface so that it can create correct node name when it is needed. The default value of the name is '', which will make sure the change is backward compatible, and won't affect existing models which suse bfloat16 namescope
diff --git a/tensorflow/python/tpu/bfloat16.py b/tensorflow/python/tpu/bfloat16.py
index 70f7181..c0bade4 100644
--- a/tensorflow/python/tpu/bfloat16.py
+++ b/tensorflow/python/tpu/bfloat16.py
@@ -70,11 +70,11 @@
@tf_export(v1=['tpu.bfloat16_scope'])
@tf_contextlib.contextmanager
-def bfloat16_scope():
+def bfloat16_scope(name = ''):
"""Scope class for bfloat16 variables so that the model uses custom getter.
This enables variables to be read as bfloat16 type when using get_variable.
"""
with variable_scope.variable_scope(
- '', custom_getter=_get_custom_getter()) as varscope:
+ name, custom_getter=_get_custom_getter()) as varscope:
yield varscope
diff --git a/tensorflow/python/tpu/bfloat16_test.py b/tensorflow/python/tpu/bfloat16_test.py
index 78157ea..5e59efb 100644
--- a/tensorflow/python/tpu/bfloat16_test.py
+++ b/tensorflow/python/tpu/bfloat16_test.py
@@ -24,15 +24,49 @@
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.tpu import bfloat16
-
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
class BFloat16ScopeTest(test.TestCase):
- def testScopeName(self):
+ def testDefaultScopeName(self):
"""Test if name for the variable scope is propagated correctly."""
with bfloat16.bfloat16_scope() as bf:
self.assertEqual(bf.name, "")
+ def testCustomScopeName(self):
+ """Test if custom name for the variable scope is propagated correctly."""
+ name = 'bfloat16'
+ with bfloat16.bfloat16_scope('bfloat16') as bf:
+ self.assertEqual(bf.name, name)
+
+ def testVariableName(self):
+ """Test if custom name for the variable scope is propagated correctly."""
+ g = ops.Graph()
+ with g.as_default():
+ a = variables.Variable(2.2, name='var_a')
+ b = variables.Variable(3.3, name='var_b')
+ d = variables.Variable(4.4, name='var_b')
+ with g.name_scope('scope1'):
+ with bfloat16.bfloat16_scope("bf16"):
+ a = math_ops.cast(a, dtypes.bfloat16)
+ b = math_ops.cast(b, dtypes.bfloat16)
+ c = math_ops.add(a, b, name='addition')
+ with bfloat16.bfloat16_scope():
+ d = math_ops.cast(d, dtypes.bfloat16)
+ math_ops.add(c, d, name='addition')
+
+ g_ops = g.get_operations()
+ ops_name = []
+ for op in g_ops:
+ ops_name.append(str(op.name))
+
+ self.assertIn('scope1/bf16/addition', ops_name)
+ self.assertIn('scope1/bf16/Cast', ops_name)
+ self.assertIn('scope1/addition', ops_name)
+ self.assertIn('scope1/Cast', ops_name)
+
@test_util.run_deprecated_v1
def testRequestedDType(self):
"""Test if requested dtype is honored in the getter.