blob: 5338725f88d23da784e9cddfe4366a4c3b1da0d0 [file] [log] [blame]
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
class DefFunctionTest(test.TestCase):
  """Checks `def_function.function` with and without `experimental_compile`.

  Most tests wrap the same Python function twice (plain and XLA-compiled)
  and compare results. Some tests skip the XLA-compiled variant on ROCm
  builds.
  """

  def testBasic(self):
    # Elementwise add of a tensor and a Python scalar, built both ways.

    def add_scalar(x, a):
      return x + a

    graph_add = def_function.function(add_scalar, experimental_compile=False)
    xla_add = def_function.function(add_scalar, experimental_compile=True)
    values = constant_op.constant([1, 2, 2, 3, 3])
    expected = [2, 3, 3, 4, 4]

    variants = [graph_add]
    if not test.is_built_with_rocm():
      # XLA support is not yet enabled for TF ROCm
      variants.append(xla_add)
    for variant in variants:
      self.assertAllClose(expected, variant(values, 1))

  def testUnsupportedOps(self):
    # An op XLA cannot compile must make the compiled variant raise, while
    # the non-compiled variant still produces the expected result.

    def unique_of(x):
      return array_ops.unique(x).y  # Unique is not supported by XLA

    graph_unique = def_function.function(unique_of, experimental_compile=False)
    xla_unique = def_function.function(unique_of, experimental_compile=True)
    values = constant_op.constant([1, 2, 2, 3, 3])

    self.assertAllClose([1, 2, 3], graph_unique(values))
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 'not compilable'):
      xla_unique(values)

  def testFunctionGradient(self):
    # Value and gradient w.r.t. a captured resource variable, both variants.
    weight = resource_variable_ops.ResourceVariable(2.0)

    def scale(x):
      return weight * x

    graph_scale = def_function.function(scale, experimental_compile=False)
    xla_scale = def_function.function(scale, experimental_compile=True)

    def check_value_and_grad(fn_under_test):
      x = constant_op.constant(3.0)
      with backprop.GradientTape() as tape:
        y = fn_under_test(x)
      # y == weight * x, so dy/dweight == x == 3.0.
      dy_dweight = tape.gradient(y, weight)
      self.assertAllClose(6.0, y)
      self.assertAllClose(3.0, dy_dweight)

    check_value_and_grad(graph_scale)
    if not test.is_built_with_rocm():
      # XLA support is not yet enabled for TF ROCm
      check_value_and_grad(xla_scale)

  def testControlFlow(self):
    # A compiled while_loop containing a cond, plus its gradient computed
    # from a second compiled function that calls the first.

    @def_function.function(experimental_compile=True)
    def loop_fn(x):
      # Tracing of this function is expected to happen inside an XLA context.
      assert control_flow_util.GraphOrParentsInXlaContext(
          ops.get_default_graph())
      x = ops.convert_to_tensor(x)

      def keep_going(i, *_):
        return i < 10

      def step(i, acc):
        next_i = i + 1
        # Add 3 while i <= 2, afterwards add x**2.
        next_acc = control_flow_ops.cond(i > 2, lambda: acc + (x**2),
                                         lambda: acc + 3)
        return next_i, next_acc

      initial = (constant_op.constant(0), constant_op.constant(3.))
      outputs = control_flow_ops.while_loop(
          keep_going, step, initial, maximum_iterations=10)
      return outputs[1]

    @def_function.function(experimental_compile=True)
    def value_and_grad(x):
      x = ops.convert_to_tensor(x)
      with backprop.GradientTape() as tape:
        tape.watch(x)
        y = loop_fn(x)
      return y, tape.gradient(y, x)

    # With x == 2: three steps add 3 (3 -> 12) and seven steps add 4
    # (12 -> 40); the gradient is 7 * 2x == 28.
    self.assertAllClose(40.0, loop_fn(2.0))
    self.assertAllClose([40.0, 28.0], value_and_grad(2.0))
if __name__ == '__main__':
  # Turn on eager execution before handing control to the test runner.
  ops.enable_eager_execution()
  test.main()