blob: 15a237d3b3a711ff0d4e3c7f4b9e8e2fac1cd2fa [file] [log] [blame]
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests Policies."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.mixed_precision.experimental import policy as mp_policy
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.platform import test
from tensorflow.python.training.experimental import mixed_precision
@test_util.run_all_in_graph_and_eager_modes
class PolicyTest(test.TestCase):
  """Unit tests for mixed precision `Policy` objects and the global policy.

  Exercises policy construction from a name, the `compute_dtype` and
  `variable_dtype` attributes, `repr`, `with_input_dtype`, the global policy
  setters/scopes, and the errors raised for unsupported usage.
  """

  def _assert_global_policy_named(self, expected_name):
    """Asserts that the current global policy's `name` is `expected_name`."""
    self.assertEqual(mp_policy.global_policy().name, expected_name)

  @testing_utils.enable_v2_dtype_behavior
  def test_dtype_attributes(self):
    """Checks `compute_dtype`/`variable_dtype` for each kind of policy name."""
    infer_policy = mp_policy.Policy('infer')
    self.assertEqual(infer_policy.compute_dtype, None)
    self.assertEqual(infer_policy.variable_dtype, None)

    infer_fp32_vars_policy = mp_policy.Policy('infer_float32_vars')
    self.assertEqual(infer_fp32_vars_policy.compute_dtype, None)
    self.assertEqual(infer_fp32_vars_policy.variable_dtype, 'float32')

    for dtype_name in ('int32', 'bool', 'float16', 'float32'):
      # A bare dtype name is used for both computations and variables.
      plain_policy = mp_policy.Policy(dtype_name)
      self.assertEqual(plain_policy.compute_dtype, dtype_name)
      self.assertEqual(plain_policy.variable_dtype, dtype_name)

      # The "_with_float32_vars" suffix only changes the variable dtype.
      mixed_policy = mp_policy.Policy(dtype_name + '_with_float32_vars')
      self.assertEqual(mixed_policy.compute_dtype, dtype_name)
      self.assertEqual(mixed_policy.variable_dtype, 'float32')

  @testing_utils.enable_v2_dtype_behavior
  def test_repr(self):
    """Checks that `repr` of a policy embeds the policy name."""
    policy_names = ('infer', 'infer_with_float32_vars', 'float32',
                    'float16_with_float32_vars')
    for policy_name in policy_names:
      actual_repr = repr(mp_policy.Policy(policy_name))
      self.assertEqual(actual_repr, '<Policy "%s">' % policy_name)

    # 'float32_with_float32_vars' is expected to be displayed as 'float32'.
    redundant_repr = repr(mp_policy.Policy('float32_with_float32_vars'))
    self.assertEqual(redundant_repr, '<Policy "float32">')

  @testing_utils.enable_v2_dtype_behavior
  def test_policy_errors(self):
    """Checks the errors raised when a policy is built from a bad `name`."""
    # Strings that do not name a valid policy raise ValueError.
    message_template = 'Cannot convert value %s to a mixed precision Policy.'
    bad_names = ('abc', 'abc_with_float32_vars', 'float32_with_float16_vars')
    for bad_name in bad_names:
      value_error_pattern = message_template % bad_name
      with self.assertRaisesRegexp(ValueError, value_error_pattern):
        mp_policy.Policy(bad_name)

    # A DType object (rather than its name) raises a dedicated TypeError.
    dtype_error_pattern = ("'name' must be a string, not a DType. "
                           "Instead, pass DType.name. Got: float16")
    with self.assertRaisesRegexp(TypeError, dtype_error_pattern):
      mp_policy.Policy(dtypes.float16)

    # Any other non-string value raises a generic TypeError.
    non_string_pattern = "'name' must be a string, but got: 5"
    with self.assertRaisesRegexp(TypeError, non_string_pattern):
      mp_policy.Policy(5)

  @testing_utils.enable_v2_dtype_behavior
  def test_with_input_dtype(self):
    """Checks the dtypes of policies returned by `with_input_dtype`."""
    # Each case: (policy name, input dtype, expected variable dtype). In every
    # case the expected compute dtype is the input dtype itself.
    cases = (
        ('infer', 'float16', 'float16'),
        ('infer_with_float32_vars', 'float16', 'float32'),
        ('infer_with_float32_vars', 'float32', 'float32'),
    )
    for policy_name, input_dtype, expected_variable_dtype in cases:
      resolved_policy = mp_policy.with_input_dtype(
          mp_policy.Policy(policy_name), input_dtype)
      self.assertEqual(resolved_policy.compute_dtype, input_dtype)
      self.assertEqual(resolved_policy.variable_dtype, expected_variable_dtype)

  @testing_utils.enable_v2_dtype_behavior
  def test_global_policy(self):
    """Checks `set_policy`/`global_policy`, restoring the default afterwards."""
    expected_default = (
        'float32' if base_layer_utils.v2_dtype_behavior_enabled() else 'infer')
    self._assert_global_policy_named(expected_default)
    try:
      mp_policy.set_policy('infer_with_float32_vars')
      self._assert_global_policy_named('infer_with_float32_vars')
      with ops.Graph().as_default():  # Policies are not associated with a graph
        self._assert_global_policy_named('infer_with_float32_vars')

      mp_policy.set_policy('infer')
      self._assert_global_policy_named('infer')

      # Setting a Policy instance makes that exact object the global policy.
      policy_instance = mp_policy.Policy('infer_with_float32_vars')
      mp_policy.set_policy(policy_instance)
      self.assertIs(mp_policy.global_policy(), policy_instance)
    finally:
      # Always clear the global policy so a failure here cannot affect other
      # tests.
      mp_policy.set_policy(None)

  @testing_utils.enable_v2_dtype_behavior
  def test_policy_scope(self):
    """Checks that nested `policy_scope`s set and restore the global policy."""
    expected_default = (
        'float32' if base_layer_utils.v2_dtype_behavior_enabled() else 'infer')
    outer_name = 'infer_with_float32_vars'
    with mp_policy.policy_scope(outer_name):
      self._assert_global_policy_named('infer_with_float32_vars')
      with mp_policy.policy_scope('infer'):
        self._assert_global_policy_named('infer')
      # Leaving the inner scope brings back the outer scope's policy.
      self._assert_global_policy_named('infer_with_float32_vars')
    # Leaving the outer scope brings back the default policy.
    self._assert_global_policy_named(expected_default)

  @testing_utils.enable_v2_dtype_behavior
  def test_error_if_graph_rewrite_enabled(self):
    """Checks `set_policy` fails once the graph rewrite has been enabled."""
    rewrite_error_pattern = ('the mixed precision graph rewrite has already '
                             'been enabled')
    try:
      optimizer = gradient_descent.SGD(1.)
      mixed_precision.enable_mixed_precision_graph_rewrite(optimizer)
      with self.assertRaisesRegexp(ValueError, rewrite_error_pattern):
        mp_policy.set_policy('infer_float32_vars')
    finally:
      # Turn the rewrite back off regardless of the outcome above.
      mixed_precision.disable_mixed_precision_graph_rewrite()

  @testing_utils.disable_v2_dtype_behavior
  def test_v1_dtype_behavior(self):
    """Checks which policies can be entered as a scope with V1 dtype behavior."""
    # These policies are allowed with V1 dtype behavior
    for allowed_name in ('infer', 'infer_float32_vars'):
      with mp_policy.policy_scope(mp_policy.Policy(allowed_name)):
        pass

    # These policies are not allowed with V1 dtype behavior
    v1_error_pattern = (
        'global policy can only be set to a non-infer policy in TensorFlow 2')
    for rejected_name in ('float32', 'float16_with_float32_vars'):
      with self.assertRaisesRegexp(ValueError, v1_error_pattern):
        with mp_policy.policy_scope(mp_policy.Policy(rejected_name)):
          pass
if __name__ == '__main__':
  # Run this module's test cases when the file is executed as a script.
  test.main()