# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for AutoCastVariable."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl.testing import parameterized
import numpy as np
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras.mixed_precision.experimental import autocast_variable
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training.tracking import util as trackable_utils
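
# Each test below runs in two named variants: 'base' (no distribution
# strategy) and 'distribute' (variables created under a MirroredStrategy).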
TESTCASES = ({
'testcase_name': 'base',
'distribute': False
}, {
'testcase_name': 'distribute',
'distribute': True
})


def get_distribute_scope(distribute):
  """Returns a MirroredStrategy scope if `distribute`, else a no-op scope."""

  class DummyContextManager(object):

    def __enter__(self):
      pass

    def __exit__(self, *args):
      pass

  if distribute:
    return mirrored_strategy.MirroredStrategy(['cpu:0']).scope()
  else:
    return DummyContextManager()


def get_autocast_var(var, distribute):
  """Wraps `var` in an AutoCastVariable, distributed if `distribute` is True."""
  if distribute:
    return autocast_variable.AutoCastDistributedVariable(var)
  else:
    return autocast_variable.AutoCastVariable(var)


def get_var(val, dtype, name=None):
  """Creates a resource variable with the given value, dtype and name."""
  return variables.VariableV1(val, use_resource=True, dtype=dtype, name=name)


@test_util.run_all_in_graph_and_eager_modes
class AutoCastVariableTest(test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(*TESTCASES)
def test_read(self, distribute):
with get_distribute_scope(distribute):
x = get_var(1., dtypes.float32)
x = get_autocast_var(x, distribute)
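      # The wrapped variable still stores float32; only reads are cast.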
self.evaluate(x.initializer)
# outside of auto cast scope.
self.assertEqual(x.dtype, dtypes.float32)
self.assertEqual(x.value().dtype, dtypes.float32)
self.assertEqual(x.read_value().dtype, dtypes.float32)
self.assertEqual(array_ops.identity(x).dtype, dtypes.float32)
# within auto cast scope of different dtype
with ops.get_default_graph()._enable_auto_casting_variables(
dtypes.float16):
self.assertEqual(x.dtype, dtypes.float16)
self.assertEqual(x.value().dtype, dtypes.float16)
self.assertEqual(x.read_value().dtype, dtypes.float16)
self.assertEqual(array_ops.identity(x).dtype, dtypes.float16)
# within auto cast scope of same dtype
with ops.get_default_graph()._enable_auto_casting_variables(
dtypes.float32):
self.assertEqual(x.dtype, dtypes.float32)
self.assertEqual(x.value().dtype, dtypes.float32)
self.assertEqual(x.read_value().dtype, dtypes.float32)
        self.assertEqual(array_ops.identity(x).dtype, dtypes.float32)

  @parameterized.named_parameters(*TESTCASES)
def test_read_nested_scopes(self, distribute):
with get_distribute_scope(distribute):
x = get_var(1., dtypes.float32)
x = get_autocast_var(x, distribute)
self.evaluate(x.initializer)
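
      # Nested autocast scopes: the innermost scope's dtype wins, and exiting
      # it restores the enclosing scope's dtype.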
with ops.get_default_graph()._enable_auto_casting_variables(
dtypes.float16):
self.assertEqual(x.dtype, dtypes.float16)
self.assertEqual(x.read_value().dtype, dtypes.float16)
with ops.get_default_graph()._enable_auto_casting_variables(
dtypes.float32):
self.assertEqual(x.dtype, dtypes.float32)
self.assertEqual(x.read_value().dtype, dtypes.float32)
self.assertEqual(x.dtype, dtypes.float16)
        self.assertEqual(x.read_value().dtype, dtypes.float16)

  @parameterized.named_parameters(*TESTCASES)
def test_operator_overloads(self, distribute):
with get_distribute_scope(distribute):
for read_dtype in (dtypes.float32, dtypes.float16):
x = get_var(7., dtypes.float32)
x = get_autocast_var(x, distribute)
with ops.get_default_graph()._enable_auto_casting_variables(
read_dtype):
self.evaluate(x.initializer)
self.assertAlmostEqual(8, self.evaluate(x + 1))
self.assertAlmostEqual(10, self.evaluate(3 + x))
self.assertAlmostEqual(14, self.evaluate(x + x))
self.assertAlmostEqual(5, self.evaluate(x - 2))
self.assertAlmostEqual(6, self.evaluate(13 - x))
self.assertAlmostEqual(0, self.evaluate(x - x))
self.assertAlmostEqual(14, self.evaluate(x * 2))
self.assertAlmostEqual(21, self.evaluate(3 * x))
self.assertAlmostEqual(49, self.evaluate(x * x))
self.assertAlmostEqual(3.5, self.evaluate(x / 2))
self.assertAlmostEqual(1.5, self.evaluate(10.5 / x))
self.assertAlmostEqual(3, self.evaluate(x // 2))
self.assertAlmostEqual(2, self.evaluate(15 // x))
if read_dtype == dtypes.float32:
# The "mod" operator does not support float16
self.assertAlmostEqual(1, self.evaluate(x % 2))
self.assertAlmostEqual(2, self.evaluate(16 % x))
self.assertTrue(self.evaluate(x < 12))
self.assertTrue(self.evaluate(x <= 12))
self.assertFalse(self.evaluate(x > 12))
self.assertFalse(self.evaluate(x >= 12))
self.assertFalse(self.evaluate(12 < x))
self.assertFalse(self.evaluate(12 <= x))
self.assertTrue(self.evaluate(12 > x))
self.assertTrue(self.evaluate(12 >= x))
self.assertAlmostEqual(343, self.evaluate(pow(x, 3)), places=4)
self.assertAlmostEqual(128, self.evaluate(pow(2, x)), places=4)
self.assertAlmostEqual(-7, self.evaluate(-x))
self.assertAlmostEqual(7, self.evaluate(abs(x)))
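
          # Indexing a wrapped vector variable should work as well.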
x = get_var([7, 8, 9], dtypes.float32)
x = get_autocast_var(x, distribute)
self.evaluate(x.initializer)
          self.assertEqual(self.evaluate(x[1]), 8)

  @parameterized.named_parameters(*TESTCASES)
def test_assign(self, distribute):
with get_distribute_scope(distribute):
x = get_var(0., dtypes.float32)
x = get_autocast_var(x, distribute)
self.evaluate(x.initializer)
# outside of auto cast scope.
v1 = constant_op.constant(3.14, dtype=dtypes.float32)
v2 = constant_op.constant(3.14, dtype=dtypes.float16)
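
      # `assign`, `assign_add` and `assign_sub` always take values in the
      # variable's true dtype (float32); float16 values are rejected both
      # outside and inside an autocast scope, so run_and_check() is called in
      # both contexts below.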
def run_and_check():
# Assign float32 values
self.assertAllClose(3.14, self.evaluate(x.assign(v1)))
self.assertAllClose(3.14 * 2, self.evaluate(x.assign_add(v1)))
self.assertAllClose(3.14, self.evaluate(x.assign_sub(v1)))
# Attempt to assign float16 values
with self.assertRaisesRegexp(
ValueError,
'conversion requested dtype float32 for Tensor with dtype float16'):
self.evaluate(x.assign(v2))
with self.assertRaisesRegexp(
ValueError,
'conversion requested dtype float32 for Tensor with dtype float16'):
self.evaluate(x.assign_add(v2))
with self.assertRaisesRegexp(
ValueError,
'conversion requested dtype float32 for Tensor with dtype float16'):
self.evaluate(x.assign_sub(v2))
# Assign Python floats
self.assertAllClose(3.14, self.evaluate(x.assign(3.14)))
self.assertAllClose(3.14 * 2, self.evaluate(x.assign_add(3.14)))
self.assertAllClose(3.14, self.evaluate(x.assign_sub(3.14)))
run_and_check()
      # Reset x before re-running the checks inside an autocast scope.
self.evaluate(x.assign(0.))
# within auto cast scope.
with ops.get_default_graph()._enable_auto_casting_variables(
dtypes.float16):
        # assign still expects a float32 value even inside a float16 scope
        run_and_check()

  @parameterized.named_parameters(*TESTCASES)
def test_assign_stays_in_true_dtype(self, distribute):
with get_distribute_scope(distribute):
x = get_var(1., dtypes.float32)
x = get_autocast_var(x, distribute)
self.evaluate(x.initializer)
# small_val is a value such that 1.0 + small_val == 1.0 in fp16, but not
# in fp32
small_val = np.finfo('float16').eps / 2
small_tensor = constant_op.constant(small_val, dtype=dtypes.float32)
with ops.get_default_graph()._enable_auto_casting_variables(
dtypes.float16):
# Variable should be increased, despite it appearing to be the same
# float16 value.
self.assertEqual(1. + small_val,
self.evaluate(x.assign(1. + small_tensor)))
self.assertEqual(1., self.evaluate(x.value()))
self.assertEqual(1. + small_val, self.evaluate(x.value()))
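
    # Repeat the same check for assign_add.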
self.evaluate(x.assign(1.))
with ops.get_default_graph()._enable_auto_casting_variables(
dtypes.float16):
self.assertEqual(1. + small_val,
self.evaluate(x.assign_add(small_tensor)))
self.assertEqual(1., self.evaluate(x.value()))
    self.assertEqual(1. + small_val, self.evaluate(x.value()))

  @parameterized.named_parameters(*TESTCASES)
def test_checkpoint(self, distribute):
with get_distribute_scope(distribute):
x = get_var(1., dtypes.float32)
x = get_autocast_var(x, distribute)
self.evaluate(x.initializer)
self.evaluate(x.assign(123.))
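
    # Checkpointing should save and restore the variable's true float32 value.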
checkpoint = trackable_utils.Checkpoint(x=x)
prefix = os.path.join(self.get_temp_dir(), 'ckpt')
save_path = checkpoint.save(prefix)
self.evaluate(x.assign(234.))
checkpoint.restore(save_path).assert_consumed().run_restore_ops()
    self.assertEqual(self.evaluate(x), 123.)

  @parameterized.named_parameters(*TESTCASES)
def test_invalid_wrapped_variable(self, distribute):
with get_distribute_scope(distribute):
# Wrap a non-variable
with self.assertRaisesRegexp(ValueError, 'variable must be of type'):
x = constant_op.constant([1.], dtype=dtypes.float32)
get_autocast_var(x, distribute)
# Wrap a non-floating point variable
with self.assertRaisesRegexp(ValueError,
'variable must be a floating point'):
x = get_var(1, dtypes.int32)
get_autocast_var(x, distribute)
if distribute:
# Wrap a non-distributed variable with AutoCastDistributedVariable
with self.assertRaisesRegexp(ValueError, 'variable must be of type'):
x = get_var(1., dtypes.float32)
          get_autocast_var(x, distribute)

  def test_repr(self):
# We do not test with DistributionStrategy because we do not want to rely on
# the exact __repr__ output of a DistributedVariable.
x = get_var(1., dtypes.float32, name='x')
x = get_autocast_var(x, distribute=False)
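
    # repr() reports the scope-dependent `dtype` alongside the fixed
    # `true_dtype`; in eager mode it also includes the numpy value.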
if context.executing_eagerly():
self.assertStartsWith(
repr(x),
"<AutoCastVariable 'x:0' shape=() dtype=float32 true_dtype=float32, "
"numpy="
)
with ops.get_default_graph()._enable_auto_casting_variables(
dtypes.float16):
self.assertStartsWith(
repr(x),
"<AutoCastVariable 'x:0' shape=() dtype=float16 "
"true_dtype=float32, numpy="
)
else:
self.assertEqual(
repr(x),
"<AutoCastVariable 'x:0' shape=() dtype=float32 true_dtype=float32>"
)
with ops.get_default_graph()._enable_auto_casting_variables(
dtypes.float16):
self.assertEqual(
repr(x),
"<AutoCastVariable 'x:0' shape=() dtype=float16 true_dtype=float32>"
        )


if __name__ == '__main__':
test.main()