# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
import time
import numpy as np
from tensorflow.core.framework import function_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import function
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.framework.errors import InvalidArgumentError
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_logging_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
def _OptimizerOptions():
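# Yields a ConfigProto for every combination of common subexpression
# elimination, function inlining and constant folding (eight configs in
# total), with the matching Grappler rewriter options toggled as well.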
for cse in [False, True]:
for inline in [False, True]:
for cfold in [False, True]:
cfg = config_pb2.ConfigProto(
graph_options=config_pb2.GraphOptions(
optimizer_options=config_pb2.OptimizerOptions(
opt_level=config_pb2.OptimizerOptions.L0,
do_common_subexpression_elimination=cse,
do_function_inlining=inline,
do_constant_folding=cfold)))
if cse:
cfg.graph_options.rewrite_options.arithmetic_optimization = (
rewriter_config_pb2.RewriterConfig.ON)
else:
cfg.graph_options.rewrite_options.arithmetic_optimization = (
rewriter_config_pb2.RewriterConfig.OFF)
if inline:
cfg.graph_options.rewrite_options.function_optimization = (
rewriter_config_pb2.RewriterConfig.ON)
else:
cfg.graph_options.rewrite_options.function_optimization = (
rewriter_config_pb2.RewriterConfig.OFF)
if cfold:
cfg.graph_options.rewrite_options.constant_folding = (
rewriter_config_pb2.RewriterConfig.ON)
else:
cfg.graph_options.rewrite_options.constant_folding = (
rewriter_config_pb2.RewriterConfig.OFF)
yield cfg
class FunctionTest(test.TestCase):
"""Test methods for verifying Function support.
These test methods are used as mix-ins in two test cases: with
and without C API support.
"""
def testIdentity(self):
@function.Defun(dtypes.float32, func_name="MyIdentity")
def MyIdentityFunc(a):
return a
with ops.Graph().as_default():
call = MyIdentityFunc([18.0])
self.assertEqual("MyIdentity", call.op.name)
with session.Session() as sess:
self.assertAllEqual([18.0], self.evaluate(call))
@test_util.run_v1_only("b/120545219")
def testIdentityImplicitDeref(self):
@function.Defun(dtypes.float32, func_name="MyIdentity")
def MyIdentityFunc(a):
return a
with ops.Graph().as_default():
var = variables.VariableV1([18.0])
call = MyIdentityFunc(var._ref()) # pylint: disable=protected-access
self.assertEqual("MyIdentity", call.op.name)
for cfg in _OptimizerOptions():
with session.Session(config=cfg) as sess:
self.evaluate(var.initializer)
self.assertAllEqual([18.0], self.evaluate(call))
def testIdentityOutputName(self):
@function.Defun(
dtypes.float32, func_name="MyIdentity", out_names=["my_result_name"])
def MyIdentityFunc(a):
return a
with ops.Graph().as_default():
call = MyIdentityFunc([18.0])
self.assertEqual("MyIdentity", call.op.name)
with session.Session() as sess:
self.assertAllEqual([18.0], self.evaluate(call))
def testTooManyOutputNames(self):
@function.Defun(
dtypes.float32,
func_name="MyIdentity",
out_names=["my_result1", "my_result2"])
def MyIdentityFunc(a):
return a
with ops.Graph().as_default():
with self.assertRaisesRegexp(
errors_impl.InvalidArgumentError,
(r"output names must be either empty or equal in size to outputs. "
"output names size = 2 outputs size = 1")):
MyIdentityFunc([18.0])
def testDefineFunction2Args(self):
@function.Defun(dtypes.float32, dtypes.float32, func_name="APlus2B")
def APlus2B(a, b):
return a + b * 2
with ops.Graph().as_default():
call = APlus2B([1.0], [2.0])
self.assertEqual("APlus2B", call.op.name)
with session.Session() as sess:
self.assertAllEqual([5.0], self.evaluate(call))
def testFunctionWithNoOutput(self):
@function.Defun(dtypes.float32, dtypes.float32)
def APlus2B(a, b):
c = a + b * 2 # Create some ops to have nodes in the body
print(c) # Using 'print' to make lint happy
with ops.Graph().as_default():
# Call function. There should be no exceptions.
APlus2B([1.0], [2.0])
def testDefineFunction2ArgsOutputName(self):
@function.Defun(
dtypes.float32,
dtypes.float32,
func_name="APlus2B",
out_names=["my_result_name"])
def APlus2B(a, b):
return a + b * 2
# APlus2B is stateless.
self.assertEqual([], APlus2B.stateful_ops)
with ops.Graph().as_default():
call = APlus2B([1.0], [2.0])
self.assertEqual("APlus2B", call.op.name)
with session.Session() as sess:
self.assertAllEqual([5.0], self.evaluate(call))
def testDefineFunctionDuplicateOutputs(self):
@function.Defun(dtypes.float32, func_name="Duplicate")
def Duplicate(a):
b = a + 1.0
return b, b
g = ops.Graph()
with g.as_default():
Duplicate([3.0])
func_sig = g.as_graph_def().library.function[0].signature
# The names given to both outputs should be different
# even though the same tensor is emitted to both.
out_names = [a.name for a in func_sig.output_arg]
self.assertEqual(2, len(out_names))
self.assertNotEqual(out_names[0], out_names[1])
def testGradientFunc(self):
@function.Defun(dtypes.float32, func_name="XSquarePlusOneFn")
def XSquarePlusOne(x):
return x * x + 1.0
@function.Defun(dtypes.float32, dtypes.float32)
def XSquarePlusOneGrad(x, dy):
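# SymbolicGradient differentiates the body of "XSquarePlusOneFn":
# d/dx (x * x + 1) = 2 * x, so the call below with x = 2.0 and dy = 0.1
# is expected to return 2 * 2.0 * 0.1 = 0.4.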
dx = functional_ops.symbolic_gradient(
input=[x, dy], Tout=[dtypes.float32], f="XSquarePlusOneFn", name="dx")
return dx
g = ops.Graph()
with g.as_default():
call_f = XSquarePlusOne([2.0])
call_g = XSquarePlusOneGrad([2.0], [0.1])
with session.Session() as sess:
self.assertAllClose([5.0], self.evaluate(call_f))
self.assertAllClose([0.4], self.evaluate(call_g))
def testTanhSymGrad(self):
@function.Defun(dtypes.float32)
def Forward(x):
return math_ops.reduce_sum(math_ops.tanh(x))
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtypes.float32)
y = Forward(x)
dx = gradients_impl.gradients([y], [x])
inp = np.array([-1, 1, 2, -2], dtype=np.float32)
feed = {x: inp}
cfg = config_pb2.ConfigProto(
graph_options=config_pb2.GraphOptions(
optimizer_options=config_pb2.OptimizerOptions(
opt_level=config_pb2.OptimizerOptions.L1,
do_function_inlining=True)))
with session.Session(graph=g, config=cfg) as sess:
out, = sess.run(dx, feed)
self.assertAllClose(1 - np.square(np.tanh(inp)), out)
def testCustomGradient(self):
dtype = dtypes.float32
@function.Defun(dtype, dtype, dtype)
def XentLossGrad(logits, labels, dloss):
dlogits = array_ops.reshape(dloss, [-1, 1]) * (
nn_ops.softmax(logits) - labels)
dlabels = array_ops.zeros_like(labels)
# Takes exp(dlogits) to differentiate it from the "correct" gradient.
return math_ops.exp(dlogits), dlabels
@function.Defun(dtype, dtype, grad_func=XentLossGrad)
def XentLoss(logits, labels):
return math_ops.reduce_sum(labels * math_ops.log(nn_ops.softmax(logits)),
1)
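# With the default grad_ys of 1, the custom gradient above evaluates to
# exp(softmax(logits) - labels), which is what the comparison against
# np.exp(prob - y) below checks.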
g = ops.Graph()
with g.as_default():
logits = array_ops.placeholder(dtype)
labels = array_ops.placeholder(dtype)
loss = XentLoss(logits, labels)
dlogits = gradients_impl.gradients([loss], [logits])
x = np.random.uniform(-10., 10., size=(4, 9)).astype(np.float32)
prob = np.exp(x) / np.sum(np.exp(x), 1, keepdims=True)
y = np.random.uniform(-10., 10., size=(4, 9)).astype(np.float32)
for cfg in _OptimizerOptions():
tf_logging.info("cfg = %s", cfg)
with session.Session(graph=g, config=cfg) as sess:
out, = sess.run(dlogits, {logits: x, labels: y})
self.assertAllClose(out, np.exp(prob - y))
@test_util.disable_xla("b/124286351") # No error is raised
def testCustomGradientError(self):
dtype = dtypes.float32
@function.Defun(dtype, dtype, dtype)
def Grad(x, dy, dz):
# Should have returned 1 result.
return x, dy + dz
@function.Defun(dtype, grad_func=Grad)
def Forward(x):
return x, x
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(dtype)
out = math_ops.add_n(Forward(inp))
dinp = gradients_impl.gradients(out, [inp])
x = np.random.uniform(-10., 10., size=(4, 9)).astype(np.float32)
with session.Session(graph=g) as sess:
with self.assertRaisesRegexp(
errors_impl.InvalidArgumentError,
"SymGrad expects to return 1.*but get 2.*instead"):
_ = sess.run(dinp, {inp: x})
def testSymGradShape(self):
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtypes.float32, [25, 4])
y = array_ops.placeholder(dtypes.float32, [200, 100])
dz = array_ops.placeholder(dtypes.float32, [1])
# We assume Foo is a function of (x, y) -> (z) Then, Foo's
# gradient function is (x, y, dz) -> (dx, dy). dx's shape
# should be the same as x's; and dy's shape should be the same
# as y's.
dx, dy = functional_ops.symbolic_gradient(
input=[x, y, dz], Tout=[dtypes.float32] * 2, f="Foo")
self.assertEqual(x.get_shape(), dx.get_shape())
self.assertEqual(y.get_shape(), dy.get_shape())
@test_util.run_deprecated_v1
def testSymGradAttr(self):
@function.Defun(noinline=True)
def Foo(x):
return x * 2
self.assertTrue(
Foo.instantiate([dtypes.float32]).definition.attr["_noinline"].b)
g = ops.Graph()
with g.as_default():
x = constant_op.constant(3.0)
y = Foo(x)
dx, = gradients_impl.gradients(y, [x])
cfg = config_pb2.ConfigProto(
graph_options=config_pb2.GraphOptions(
optimizer_options=config_pb2.OptimizerOptions(
opt_level=config_pb2.OptimizerOptions.L0,
do_common_subexpression_elimination=True,
do_function_inlining=True,
do_constant_folding=True)))
with self.session(graph=g, config=cfg):
self.assertAllClose(y.eval(), 6.)
self.assertAllClose(dx.eval(), 2.)
def _testZNoDepOnY(self, use_const_grad_ys):
@function.Defun(dtypes.float32, dtypes.float32)
def Foo(x, y): # pylint: disable=unused-argument
return x * 2
with ops.Graph().as_default():
# z = Foo(x, y). z does not depend on y.
x = constant_op.constant(1.0)
y = constant_op.constant(2.0)
z = Foo(x, y)
if use_const_grad_ys:
dx, dy = gradients_impl.gradients([z], [x, y], grad_ys=[1.0])
else:
dx, dy = gradients_impl.gradients([z], [x, y])
with session.Session() as sess:
dx_val, dy_val = self.evaluate([dx, dy])
self.assertEqual([2.0], dx_val)
self.assertEqual([0.0], dy_val)
def testZNoDepOnY(self):
self._testZNoDepOnY(False)
def testZNoDepOnYConstGradYs(self):
# Tests for constant folding of grad_ys
self._testZNoDepOnY(True)
def testDefineFunctionNoArgs(self):
@function.Defun(func_name="AConstant")
def AConstant():
return constant_op.constant([42])
with ops.Graph().as_default():
call = AConstant()
self.assertEqual("AConstant", call.op.name)
with session.Session() as sess:
self.assertAllEqual([42], self.evaluate(call))
def testDefineFunctionNames(self):
@function.Defun(dtypes.float32, func_name="Foo")
def Foo(a):
return a + 1
with ops.Graph().as_default():
call1 = Foo([1.0])
self.assertEqual("Foo", call1.op.name)
call2 = Foo([1.0])
self.assertEqual("Foo_1", call2.op.name)
# pylint: disable=unexpected-keyword-arg
call3 = Foo([1.0], name="mine")
self.assertEqual("mine", call3.op.name)
with ops.name_scope("my"):
call4 = Foo([1.0], name="precious")
self.assertEqual("my/precious", call4.op.name)
def testNoOp(self):
@function.Defun(dtypes.float32)
def Foo(x):
y = logging_ops.Print(x, [], "Hello")
with ops.control_dependencies([y]):
z = control_flow_ops.no_op()
with ops.control_dependencies([z]):
return x * 2
with ops.Graph().as_default(), self.cached_session():
z = Foo(constant_op.constant(3.0))
self.assertAllEqual(z.eval(), 6.0)
def testAssertOp(self):
@function.Defun(dtypes.float32)
def Foo(x):
check = gen_logging_ops._assert(math_ops.greater(x, 0), [x])
with ops.control_dependencies([check]):
return x * 2
# Foo contains a stateful op (Assert).
self.assertEqual([("Assert", "Assert")], Foo.stateful_ops)
g = ops.Graph()
with g.as_default(), self.cached_session():
self.assertAllEqual(Foo(constant_op.constant(3.0)).eval(), 6.0)
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"assertion failed.*-3"):
self.assertAllEqual(Foo(constant_op.constant(-3.0)).eval(), 6.0)
@test_util.run_deprecated_v1
def testAssertWrapper(self):
@function.Defun(dtypes.float32)
def MyFn(x):
with ops.control_dependencies(
[control_flow_ops.Assert(math_ops.less_equal(x, 10.0), [x])]):
return array_ops.identity(x)
with self.cached_session():
self.assertEqual(1.0, MyFn(1.0).eval())
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"assertion"):
_ = MyFn(100.0).eval()
@test_util.run_deprecated_v1
def testWhileLoopCallsFunc(self):
with self.session(use_gpu=True) as sess:
@function.Defun(dtypes.float32)
def Times2(x):
constant_two = constant_op.constant(2, dtypes.int32)
two_on_gpu = math_ops.cast(constant_two, dtypes.float32)
return x * two_on_gpu
def Body(x):
x2 = Times2(x)
x2.set_shape([])
return x2
loop = control_flow_ops.while_loop(lambda x: x < 1e5, Body, [1.0])
ans = self.evaluate(loop)
self.assertAllClose(ans, 131072.)
@test_util.run_deprecated_v1
def testControlFlowStrictness(self):
"""Inlined functions must not execute in a untaken control flow branch."""
@function.Defun(dtypes.int32)
def AssertFail(x):
# Assertion that always fails and does not have a data dependency on `x`.
assert_false = control_flow_ops.Assert(False, [42])
with ops.control_dependencies([assert_false]):
return array_ops.identity(x)
with ops.device("CPU"):
pred = array_ops.placeholder(dtypes.bool)
x = array_ops.placeholder(dtypes.int32)
cond = control_flow_ops.cond(pred, lambda: x + 1, lambda: AssertFail(x))
# pylint: disable=unnecessary-lambda
loop = control_flow_ops.while_loop(lambda y: pred,
lambda y: AssertFail(y), [x])
# pylint: enable=unnecessary-lambda
rewriter_config = rewriter_config_pb2.RewriterConfig(
dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF)
# Enables inlining.
config = config_pb2.ConfigProto(
graph_options=config_pb2.GraphOptions(
optimizer_options=config_pb2.OptimizerOptions(
opt_level=config_pb2.OptimizerOptions.L0,
do_common_subexpression_elimination=True,
do_function_inlining=True,
do_constant_folding=True),
rewrite_options=rewriter_config))
with session.Session(config=config) as sess:
# Since the 'False' branch is not taken, the assertion should not fire.
self.assertEqual(4, sess.run(cond, {pred: True, x: 3}))
# The assertion should still fire if the False branch is taken.
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"assertion"):
sess.run(cond, {pred: False, x: 3})
# Similarly for loops.
self.assertEqual(3, sess.run(loop, {pred: False, x: 3}))
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"assertion"):
sess.run(loop, {pred: True, x: 3})
@test_util.run_deprecated_v1
def testVar(self):
@function.Defun(dtypes.float32)
def Foo(x):
return x * x + 1
g = ops.Graph()
with g.as_default():
v = variables.Variable(constant_op.constant(10.0))
z = Foo(v)
with self.session(graph=g):
variables.global_variables_initializer().run()
self.assertAllEqual(z.eval(), 101.)
@test_util.run_deprecated_v1
def testResourceVarAsImplicitInput(self):
g = ops.Graph()
with g.as_default(), ops.device("cpu:0"):
expected_type = dtypes.float32
expected_shape = tensor_shape.TensorShape((4, 4))
v = variable_scope.get_variable(
"var", expected_shape, expected_type, use_resource=True)
@function.Defun()
def Foo():
captured = array_ops.identity(v)
self.assertEqual(expected_type, captured.dtype)
self.assertEqual(expected_shape, captured.shape)
return captured, array_ops.shape(captured)
expected_val = v.value()
actual_val, actual_shape = Foo()
with self.session(graph=g):
v.initializer.run()
self.assertAllEqual(expected_val.eval(), self.evaluate(actual_val))
self.assertAllEqual(expected_shape, self.evaluate(actual_shape))
def testDefineErrors(self):
with ops.Graph().as_default():
with self.assertRaisesRegexp(ValueError, "can not return None"):
@function.Defun()
def TwoNone():
return None, None
_ = TwoNone.definition
with self.assertRaisesRegexp(ValueError, "are not supported"):
@function.Defun()
def DefaultArg(unused_a=12):
return constant_op.constant([1])
_ = DefaultArg.definition
with self.assertRaisesRegexp(ValueError, "are not supported"):
@function.Defun()
def KwArgs(**unused_kwargs):
return constant_op.constant([1])
_ = KwArgs.definition
with self.assertRaisesRegexp(ValueError, "specified input types"):
@function.Defun(dtypes.float32)
def PlusMinusV2(a, b):
return a + b, b - a
_ = PlusMinusV2.definition
with self.assertRaisesRegexp(ValueError, "specified input types"):
@function.Defun(dtypes.float32, dtypes.float32, dtypes.float32)
def PlusMinusV3(a, b):
return a + b, b - a
_ = PlusMinusV3.definition
def testCallErrors(self):
@function.Defun()
def Const():
return constant_op.constant(1)
@function.Defun(dtypes.int32)
def PlusOne(a):
return a + 1
@function.Defun(dtypes.int32, dtypes.int32)
def PlusMinus(a, b):
return a + b, b - a
with ops.Graph().as_default():
_ = Const()
# pylint: disable=too-many-function-args
# pylint: disable=unexpected-keyword-arg
# pylint: disable=no-value-for-parameter
with self.assertRaisesRegexp(ValueError, "arguments: 0"):
_ = Const(1)
with self.assertRaisesRegexp(ValueError, "arguments: 0"):
_ = Const(1, 2)
with self.assertRaisesRegexp(ValueError, "arguments: 1"):
_ = PlusOne()
_ = PlusOne(1)
with self.assertRaisesRegexp(ValueError, "arguments: 1"):
_ = PlusOne(1, 2)
with self.assertRaisesRegexp(ValueError, "arguments: 2"):
_ = PlusMinus()
with self.assertRaisesRegexp(ValueError, "arguments: 2"):
_ = PlusMinus(1)
_ = PlusMinus(1, 2)
_ = PlusOne(1, name="p1")
with self.assertRaisesRegexp(ValueError, "Unknown keyword arguments"):
_ = PlusOne(1, device="/device:GPU:0")
def testFunctionDecorator(self):
@function.Defun(dtypes.float32, func_name="Minus1")
def Minus1(b):
return b - 1.0
with ops.Graph().as_default():
call1 = Minus1([2.])
self.assertTrue(isinstance(Minus1, function._DefinedFunction))
self.assertEqual(Minus1.name, "Minus1")
# pylint: disable=unexpected-keyword-arg
call2 = Minus1(call1, name="next")
# pylint: enable=unexpected-keyword-arg
self.assertEqual("next", call2.op.name)
with session.Session() as sess:
self.assertAllEqual([1], self.evaluate(call1))
self.assertAllEqual([0], self.evaluate(call2))
def testNestedFunction(self):
@function.Defun(dtypes.float32)
def Cube(x):
return x * x * x
@function.Defun(dtypes.float32, dtypes.float32)
def CubeXPlusY(x, y):
return Cube(x) + y
with ops.Graph().as_default():
z = CubeXPlusY(3.0, -2.0)
with self.cached_session():
self.assertAllEqual(z.eval(), 25.0)
def testNestedDefinedFunction(self):
@function.Defun(dtypes.float32, dtypes.float32)
def CubeXPlusY(x, y):
@function.Defun(dtypes.float32)
def Cube(x):
return x * x * x
return Cube(x) + y
with ops.Graph().as_default():
z = CubeXPlusY(3.0, -2.0)
with self.cached_session():
self.assertAllEqual(z.eval(), 25.0)
def testUnusedFunction(self):
invoked = False
# pylint: disable=unused-variable
@function.Defun()
def Unused():
invoked = True
return constant_op.constant(42.)
self.assertFalse(invoked)
g = ops.Graph()
with g.as_default():
@function.Defun()
def Unused2():
invoked = True
return constant_op.constant(7.)
constant_op.constant(3.)
# pylint: enable=unused-variable
self.assertFalse(invoked)
gdef = g.as_graph_def()
self.assertEqual(0, len(gdef.library.function))
@test_util.run_deprecated_v1
def testReduction(self):
g = ops.Graph()
# BN0 computes a batch-normalized matrix along rows.
def BN0(x):
mean = math_ops.reduce_mean(x, [0])
var = math_ops.reduce_mean(math_ops.square(x - mean)) # biased var
rstd = math_ops.rsqrt(var + 1e-8)
return (x - mean) * rstd
# Wraps BatchNorm in a tf function.
@function.Defun(dtypes.float32)
def BN1(x):
return BN0(x)
with g.as_default():
x = array_ops.placeholder(dtypes.float32)
y0 = BN0(x) # A plain graph
y1 = BN1(x) # A tf function
dx0, = gradients_impl.gradients([y0], [x])
dx1, = gradients_impl.gradients([y1], [x])
# Both should produce the same result and gradient.
with self.session(graph=g) as sess:
vals = sess.run([y0, y1, dx0, dx1], {x: np.random.uniform(size=(3, 7))})
self.assertAllClose(vals[0], vals[1])
self.assertAllClose(vals[2], vals[3])
@test_util.run_deprecated_v1
def testCapture(self):
g = ops.Graph()
with g.as_default():
w = variables.Variable(constant_op.constant([[1.0]]))
b = variables.Variable(constant_op.constant([2.0]))
# Foo() captures w and b.
@function.Defun(dtypes.float32)
def Foo(x):
# Plus() captures b.
@function.Defun(dtypes.float32)
def Plus(y):
return y + b
return Plus(math_ops.matmul(w, x))
y = Foo(constant_op.constant([[10.]]))
@function.Defun()
def Bar():
return w
z = Bar()
with self.session(graph=g):
variables.global_variables_initializer().run()
self.assertAllEqual(y.eval(), [[12.0]])
self.assertAllEqual(z.eval(), [[1.0]])
def testCaptureControls(self):
g = ops.Graph()
with g.as_default():
x = constant_op.constant([10.0])
x = logging_ops.Print(x, [x], "outer")
@function.Defun(dtypes.float32)
def Foo(y):
with ops.control_dependencies([x]):
y = logging_ops.Print(y, [y], "inner")
return y
with self.assertRaisesRegexp(ValueError, "not an element of this graph."):
# NOTE: We still do not support capturing control deps.
_ = Foo(x)
@test_util.run_deprecated_v1
def testCaptureInWhileLoop(self):
g = ops.Graph()
with g.as_default():
x = constant_op.constant(1)
@function.Defun()
def Foo():
return control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + x,
[0])
y = Foo()
with self.session(graph=g) as sess:
self.assertEqual(self.evaluate(y), 10)
@test_util.run_deprecated_v1
def testCaptureInCond(self):
g = ops.Graph()
with g.as_default():
x = constant_op.constant(1)
@function.Defun(dtypes.bool)
def Foo(pred):
return control_flow_ops.cond(pred, lambda: x, lambda: x + 1)
y = Foo(True)
z = Foo(False)
with self.session(graph=g) as sess:
self.assertEqual(self.evaluate(y), 1)
self.assertEqual(self.evaluate(z), 2)
@test_util.run_deprecated_v1
def testSignatureHash(self):
# Foo.Inner and Bar.Inner have identical function body but have
# different signatures. They should be treated as two different functions.
@function.Defun()
def Foo(x):
@function.Defun()
def Inner(x):
return x + 10.
return Inner(x)
@function.Defun()
def Bar(x):
@function.Defun()
def Inner(x, unused_y, unused_z):
return x + 10.
return Inner(x, 2., 3.)
g = ops.Graph()
with g.as_default():
x = constant_op.constant(10.0)
y = Foo(x)
z = Bar(x)
with self.session(graph=g) as sess:
v0, v1 = self.evaluate([y, z])
self.assertAllEqual(v0, 20.)
self.assertAllEqual(v1, 20.)
def testShapeFunction(self):
@function.Defun(
dtypes.float32, shape_func=lambda op: [op.inputs[0].get_shape()])
def Foo(x):
return x + 1.0
@function.Defun(
shape_func=lambda op: [[1] + op.inputs[0].get_shape().as_list()])
def Bar(x):
return array_ops.stack([x])
g = ops.Graph()
with g.as_default():
x = Foo([1.0, 2.0])
self.assertEqual(x.get_shape().as_list(), [2])
y = Bar(array_ops.zeros([1, 2, 3]))
self.assertAllEqual(y.get_shape().as_list(), [1, 1, 2, 3])
@test_util.run_deprecated_v1
def testVariableReuse(self):
def LinearWithReuse(input_tensor, reuse=None):
size = input_tensor.shape.dims[1]
with variable_scope.variable_scope("linear", reuse=reuse):
w = variable_scope.get_variable(
"w", shape=[size, size], dtype=input_tensor.dtype)
return math_ops.matmul(input_tensor, w)
@function.Defun(dtypes.float32)
def Foo(inputs):
inputs = array_ops.reshape(inputs, [32, 100])
hidden = LinearWithReuse(inputs)
return LinearWithReuse(hidden, reuse=True)
input_op = array_ops.placeholder(shape=[32, 100], dtype=dtypes.float32)
output_op = Foo(input_op)
global_vars = variables.global_variables()
self.assertEqual(len(global_vars), 1)
self.assertEqual(global_vars[0].name, "linear/w:0")
with session.Session() as sess:
self.evaluate(variables.global_variables_initializer())
output_val = sess.run(
output_op, feed_dict={input_op: np.random.rand(32, 100)})
self.assertEqual(output_val.shape, (32, 100))
@test_util.run_deprecated_v1
def testFunctionCallInDifferentVariableScopes(self):
@function.Defun(dtypes.float32)
def Foo(inputs):
var = variable_scope.get_variable(
"var",
shape=[10],
dtype=dtypes.float32,
initializer=init_ops.ones_initializer())
return inputs + var
input_op = array_ops.placeholder(shape=[10], dtype=dtypes.float32)
with variable_scope.variable_scope("vs1"):
out1_op = Foo(input_op)
with variable_scope.variable_scope("vs2"):
out2_op = Foo(input_op)
global_vars = variables.global_variables()
self.assertEqual(len(global_vars), 1)
self.assertEqual(global_vars[0].name, "vs1/var:0")
with session.Session() as sess:
self.evaluate(variables.global_variables_initializer())
out1, out2 = sess.run(
[out1_op, out2_op], feed_dict={input_op: np.linspace(1, 10, 10)})
self.assertAllEqual(out1, np.linspace(2, 11, 10))
self.assertAllEqual(out2, np.linspace(2, 11, 10))
def testTwoInputsSameOp(self):
g = ops.Graph()
with g.as_default():
m = array_ops.placeholder(dtypes.float32)
s, u, v = linalg_ops.svd(m)
ss = math_ops.reduce_sum(s)
uu = math_ops.reduce_sum(u)
vv = math_ops.reduce_sum(v)
result = ss + uu + vv
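# s, u and v are all outputs of the single Svd op above; converting the
# graph to a FunctionDef should still yield three distinct input args.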
f = graph_to_function_def.graph_to_function_def(
g,
g.get_operations()[1:], # skip the placeholder
[s, u, v],
[result])
self.assertEqual(len(f.signature.input_arg), 3)
def testGradientWithIntegerFunctionArgument(self):
@function.Defun(dtypes.int32, dtypes.float32)
def Foo(t, x):
return x[t]
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(dtypes.float32)
t = constant_op.constant(0, dtypes.int32)
out = Foo(t, inp)
dinp, = gradients_impl.gradients(out, [inp])
x = np.zeros((2,)).astype(np.float32)
with session.Session(graph=g) as sess:
self.assertAllClose(
np.array([1.0, 0.0]).astype(np.float32), sess.run(dinp, {inp: x}))
@test_util.run_deprecated_v1
def testFunctionMarkedStateful(self):
@function.Defun(dtypes.int32, dtypes.float32)
def Foo(t, x):
return x[t]
@function.Defun(dtypes.int64)
def Bar(x):
return x
# NOTE(mrry): All functions are currently considered stateless by the
# runtime, so we simulate a "stateful" function.
# TODO(b/70565970): Remove this hack when we are able to build stateful
# functions using the API.
# pylint: disable=protected-access
Foo._signature.is_stateful = True
Bar._signature.is_stateful = True
# pylint: enable=protected-access
result_1 = Foo(3, [1.0, 2.0, 3.0, 4.0])
result_2 = Bar(constant_op.constant(100, dtype=dtypes.int64))
with session.Session() as sess:
self.assertEqual(4.0, self.evaluate(result_1))
self.assertEqual(100, self.evaluate(result_2))
self.assertEqual((4.0, 100), sess.run((result_1, result_2)))
@test_util.run_deprecated_v1
def testStatefulFunction(self):
@function.Defun()
def FunctionWithStatelessOp():
return constant_op.constant(42.0)
@function.Defun()
def FunctionWithStatefulOp():
return random_ops.random_uniform([100], maxval=10, dtype=dtypes.int32)
@function.Defun()
def FunctionWithStatelessFunctionCall():
return FunctionWithStatelessOp()
@function.Defun()
def FunctionWithStatefulFunctionCall():
return FunctionWithStatefulOp()
# Test that the `is_stateful` bit is propagated.
self.assertFalse(FunctionWithStatelessOp.definition.signature.is_stateful)
self.assertTrue(FunctionWithStatefulOp.definition.signature.is_stateful)
self.assertFalse(
FunctionWithStatelessFunctionCall.definition.signature.is_stateful)
self.assertTrue(
FunctionWithStatefulFunctionCall.definition.signature.is_stateful)
# Ensure that two invocations of the same random-number-generating
# function produce different results.
result1 = FunctionWithStatefulFunctionCall()
result2 = FunctionWithStatefulFunctionCall()
# Statefulness affects how the function is treated by the various
# optimization passes, so run the test in each optimizer
# configuration.
for config in _OptimizerOptions():
with session.Session(config=config) as sess:
val1, val2 = sess.run((result1, result2))
self.assertFalse(all(val1 == val2))
val3, val4 = sess.run((result1, result2))
self.assertFalse(all(val3 == val1))
self.assertFalse(all(val4 == val2))
@test_util.run_v1_only("currently failing on v2")
def testStatefulFunctionWithWhitelisting(self):
t = random_ops.random_uniform([100], maxval=10, dtype=dtypes.int32)
@function.Defun(capture_by_value=True)
def StatefulFn():
return t + constant_op.constant(3, dtype=dtypes.int32)
# First time we try to capture a stateful RandomUniform op.
with self.assertRaisesRegexp(ValueError, "Cannot capture a stateful node"):
res = StatefulFn()
# This time we whitelist this op, so that it is recreated.
@function.Defun(capture_by_value=True, whitelisted_stateful_ops=set([t.op]))
def StatefulFn2():
return t + constant_op.constant(3, dtype=dtypes.int32)
res = StatefulFn2()
with session.Session() as sess:
r = sess.run(res)
for i in r:
self.assertGreaterEqual(i, 3)
@test_util.run_deprecated_v1
def testSameFunctionOnTwoDevices(self):
@function.Defun(dtypes.float32)
def AddOne(x):
return x + 1.0
with ops.device("/cpu:0"):
f_0 = AddOne(41.0)
with ops.device("/cpu:1"):
f_1 = AddOne(43.0)
for config in _OptimizerOptions():
config.device_count["CPU"] = 2
with session.Session(config=config) as sess:
self.assertEqual(42.0, self.evaluate(f_0))
self.assertEqual(44.0, self.evaluate(f_1))
self.assertEqual((42.0, 44.0), sess.run((f_0, f_1)))
@test_util.run_deprecated_v1
def testGuaranteedConstsAreCaptured(self):
var = variables.Variable(1.0)
const = array_ops.guarantee_const(var)
also_const = array_ops.identity(const)
still_const = array_ops.identity(also_const)
not_const = still_const + var
also_not_const = array_ops.placeholder(dtypes.float32)
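# Inside the function below, the captures whose values are guaranteed
# constant (const, also_const, still_const) should be consumed by
# GuaranteeConst ops, while not_const and also_not_const should not be.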
@function.Defun()
def CapturesGuaranteedConst():
output = const + also_const + still_const + not_const + also_not_const
first, second, third, fourth, fifth = function.get_extra_args()
self.assertEqual("GuaranteeConst", first.consumers()[0].node_def.op)
self.assertEqual("GuaranteeConst", second.consumers()[0].node_def.op)
self.assertEqual("GuaranteeConst", third.consumers()[0].node_def.op)
self.assertNotEqual("GuaranteeConst", fourth.consumers()[0].node_def.op)
self.assertNotEqual("GuaranteeConst", fifth.consumers()[0].node_def.op)
return output
with self.session(use_gpu=False) as sess:
self.evaluate(var.initializer)
_ = sess.run(CapturesGuaranteedConst(), {also_not_const: 1.0})
@test_util.run_deprecated_v1
def testSameFunctionDifferentGrads(self):
def PartOne(x):
# Default grad is dx = dy * 2
@function.Defun(dtypes.float32)
def Foo(x):
return x * 2
return Foo(x)
def PartTwo(x):
@function.Defun(dtypes.float32, dtypes.float32)
def Bar(x, dy):
return x + dy # crazy backprop
@function.Defun(dtypes.float32, grad_func=Bar)
def Foo(x):
return x * 2
return Foo(x)
def PartThree(x):
def Bar(op, dy):
return op.inputs[0] * dy / 2 # crazy backprop
@function.Defun(dtypes.float32, python_grad_func=Bar)
def Foo(x):
return x * 2
return Foo(x)
g = ops.Graph()
with g.as_default():
x = constant_op.constant(100.)
x0 = x
y0 = PartOne(x0)
dx0, = gradients_impl.gradients(ys=[y0], xs=[x0])
x1 = x
y1 = PartTwo(x1)
dx1, = gradients_impl.gradients(ys=[y1], xs=[x1])
x2 = x
y2 = PartThree(x2)
dx2, = gradients_impl.gradients(ys=[y2], xs=[x2])
with self.session(graph=g) as sess:
v0, v1, v2 = self.evaluate([dx0, dx1, dx2])
self.assertAllEqual(v0, 2.)
self.assertAllEqual(v1, 101.)
self.assertAllEqual(v2, 50.)
class FunctionsFromProtos(test.TestCase):
def expectFunctionsEqual(self, func, grad_func=None, new_func=None):
if new_func is None:
# Make a copy of func.definition to avoid any bugs masked by using the
# same object
serialized_fdef = func.definition.SerializeToString()
# Serialize and then deserialize `func` to create `new_func`
fdef = function_pb2.FunctionDef.FromString(serialized_fdef)
new_func = function._from_definition(fdef, grad_func=grad_func)
self.assertEqual(func.name, new_func.name)
self.assertEqual(func.definition, new_func.definition)
self.assertEqual(func.grad_func_name, new_func.grad_func_name)
self.assertEqual(func.declared_input_types, new_func.declared_input_types)
self.assertEqual(func.captured_inputs, new_func.captured_inputs)
@test_util.run_deprecated_v1
def testBasic(self):
@function.Defun(dtypes.float32, dtypes.float32)
def Foo(x, y):
return x + y
self.expectFunctionsEqual(Foo)
def testGradFunc(self):
@function.Defun(dtypes.float32, dtypes.float32)
def G(x, dy):
return x * dy
@function.Defun(dtypes.float32, grad_func=G)
def F(x):
return math_ops.exp(x) - math_ops.exp(-x)
self.expectFunctionsEqual(F, grad_func=G)
def testCapturedInputs(self):
c = constant_op.constant(10, dtypes.int64)
@function.Defun(dtypes.int64)
def Foo(x):
return x + c
new_func = function._from_definition(Foo.definition)
self.assertEqual(Foo.name, new_func.name)
self.assertEqual(Foo.definition, new_func.definition)
self.assertEqual(Foo.grad_func_name, new_func.grad_func_name)
# Captured inputs are added as regular inputs to the function definition
self.assertEqual(new_func.declared_input_types,
Foo.declared_input_types + (dtypes.int64,))
self.assertEqual(len(new_func.captured_inputs), 0)
def testNestedFunctions(self):
@function.Defun(dtypes.float32)
def Outer(x):
@function.Defun(dtypes.float32)
def Inner(y):
return y + 1
return Inner(Inner(x))
self.expectFunctionsEqual(Outer)
def testFromLibrary(self):
# Define some functions with different gradient functions. Note that many of
# the below functions are identical since function bodies don't matter for
# this test.
@function.Defun(dtypes.float32, dtypes.float32)
def G1(x, dy):
return x * dy
@function.Defun(dtypes.float32, dtypes.float32)
def G2(x, dy):
return x * dy
# F1 and F2 have the same gradient function
@function.Defun(dtypes.float32, grad_func=G1)
def F1(x):
return math_ops.exp(x) - math_ops.exp(-x)
@function.Defun(dtypes.float32, grad_func=G1)
def F2(x):
return math_ops.exp(x) - math_ops.exp(-x)
# F3 has a different gradient function
@function.Defun(dtypes.float32, grad_func=G2)
def F3(x):
return math_ops.exp(x) - math_ops.exp(-x)
# F4 has no gradient function
@function.Defun(dtypes.float32)
def F4(x):
return math_ops.exp(x) - math_ops.exp(-x)
# Instantiate all functions
g = ops.Graph()
with g.as_default():
c = constant_op.constant(1.0, dtypes.float32)
f1 = F1(c)
f2 = F2(c)
f3 = F3(c)
f4 = F4(c)
gradients_impl.gradients([f1, f2, f3, f4], c)
library = g.as_graph_def().library
new_funcs = function.from_library(library)
def CheckNewFunc(func):
new_func = [f for f in new_funcs if f.name == func.name]
self.assertEqual(len(new_func), 1)
self.expectFunctionsEqual(func, new_func=new_func[0])
CheckNewFunc(G1)
CheckNewFunc(G2)
CheckNewFunc(F1)
CheckNewFunc(F2)
CheckNewFunc(F3)
CheckNewFunc(F4)
def testFromLibraryEmptyLib(self):
library = function_pb2.FunctionDefLibrary()
self.assertEqual(len(function.from_library(library)), 0)
def testFromLibraryMissingFuncDef(self):
@function.Defun(dtypes.float32, dtypes.float32)
def G1(x, dy):
return x * dy
@function.Defun(dtypes.float32)
def F1(x):
return math_ops.exp(x) - math_ops.exp(-x)
gradient = function_pb2.GradientDef()
gradient.function_name = F1.name
gradient.gradient_func = G1.name
# Create invalid function def that is missing G1 function def
library = function_pb2.FunctionDefLibrary()
library.gradient.extend([gradient])
library.function.extend([F1.definition])
with self.assertRaisesRegexp(
ValueError,
"FunctionDefLibrary missing 'G1_[0-9a-zA-Z]{8,11}' FunctionDef"):
function.from_library(library)
# Create invalid function def that is missing F1 function def
library = function_pb2.FunctionDefLibrary()
library.gradient.extend([gradient])
library.function.extend([G1.definition])
with self.assertRaisesRegexp(
ValueError,
"FunctionDefLibrary missing 'F1_[0-9a-zA-Z]{8,11}' FunctionDef"):
function.from_library(library)
def testFromLibraryCyclicGradFuncs(self):
@function.Defun(dtypes.float32)
def F1(x):
return math_ops.exp(x) - math_ops.exp(-x)
@function.Defun(dtypes.float32)
def F2(x):
return math_ops.exp(x) - math_ops.exp(-x)
# Create invalid function def library where F1 has gradient function F2 and
# F2 has gradient function F1
library = function_pb2.FunctionDefLibrary()
library.function.extend([F1.definition, F2.definition])
gradient1 = function_pb2.GradientDef()
gradient1.function_name = F1.name
gradient1.gradient_func = F2.name
gradient2 = function_pb2.GradientDef()
gradient2.function_name = F2.name
gradient2.gradient_func = F1.name
library.gradient.extend([gradient1, gradient2])
with self.assertRaisesRegexp(
ValueError, "FunctionDefLibrary contains cyclic gradient functions!"):
function.from_library(library)
def testExperimentalAttrs(self):
@function.Defun(dtypes.int32, experimental_tag="tag_value")
def FunctionWithStrAttr(i):
return array_ops.identity(i)
@function.Defun(dtypes.int32, experimental_tag=123)
def FunctionWithIntAttr(i):
return array_ops.identity(i)
@function.Defun(dtypes.int32, experimental_tag=123.0)
def FunctionWithFloatAttr(i):
return array_ops.identity(i)
@function.Defun(dtypes.int32, experimental_tag=True)
def FunctionWithBoolAttr(i):
return array_ops.identity(i)
self.assertTrue("experimental_tag" in FunctionWithStrAttr.definition.attr)
self.assertEqual(FunctionWithStrAttr.definition.attr["experimental_tag"].s,
b"tag_value")
self.assertTrue("experimental_tag" in FunctionWithIntAttr.definition.attr)
self.assertEqual(FunctionWithIntAttr.definition.attr["experimental_tag"].i,
123)
self.assertTrue("experimental_tag" in FunctionWithFloatAttr.definition.attr)
self.assertEqual(
FunctionWithFloatAttr.definition.attr["experimental_tag"].f, 123.0)
self.assertTrue("experimental_tag" in FunctionWithBoolAttr.definition.attr)
self.assertEqual(FunctionWithBoolAttr.definition.attr["experimental_tag"].b,
True)
class FunctionOverloadTest(test.TestCase):
@test_util.run_deprecated_v1
def testBasic(self):
@function.Defun()
def Sinh(x):
return 1 / 2. * (math_ops.exp(x) - math_ops.exp(-x))
g = ops.Graph()
with g.as_default():
x = Sinh(constant_op.constant(0.25, dtypes.float32))
y = Sinh(constant_op.constant(0.25, dtypes.float64))
with self.session(graph=g):
self.assertAllClose(x.eval(), np.sinh(0.25))
self.assertAllClose(y.eval(), np.sinh(0.25))
def testGradient(self):
@function.Defun(func_name="Spec")
def G(x, dy):
return x * dy
@function.Defun(grad_func=G)
def F(x):
return math_ops.exp(x) - math_ops.exp(-x)
for dtype in [dtypes.float32, dtypes.float64]:
g = ops.Graph()
with g.as_default():
x = constant_op.constant(0.25, dtype)
y = F(x)
dx, = gradients_impl.gradients(y, x)
with self.session(graph=g):
self.assertAllClose(dx.eval(), 0.25)
def testDocString(self):
@function.Defun()
def Foo(x):
"""Successor of x."""
return x + 1
g = ops.Graph()
with g.as_default():
_ = Foo(1)
self.assertEqual(g.as_graph_def().library.function[0].signature.description,
"Successor of x.")
class FunctionCaptureByValueTest(test.TestCase):
@test_util.run_deprecated_v1
def testCaptureByValue(self):
g = ops.Graph()
with g.as_default():
w = constant_op.constant([[1.0]])
b = constant_op.constant([2.0])
# Foo() captures w and b.
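# With capture_by_value=True the captured tensors are materialized as
# constants inside the function body, so captured_inputs stays empty.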
@function.Defun(dtypes.float32, capture_by_value=True)
def Foo(x):
# Plus() captures b.
@function.Defun(dtypes.float32, capture_by_value=True)
def Plus(y):
return y + b
self.assertEqual(0, len(Plus.captured_inputs))
return Plus(math_ops.matmul(w, x))
y = Foo(constant_op.constant([[10.]]))
self.assertEqual(0, len(Foo.captured_inputs))
with self.session(graph=g):
self.assertAllEqual(y.eval(), [[12.0]])
class UnrollLSTMTest(test.TestCase):
BATCH_SIZE = 16
LSTM_DIMS = 32
NUM_UNROLL = 20
def _Weights(self):
dims = self.LSTM_DIMS
return random_ops.random_uniform([2 * dims, 4 * dims], -1, 1, seed=123456)
def _Input(self):
return random_ops.random_uniform(
[self.NUM_UNROLL, self.BATCH_SIZE, self.LSTM_DIMS], seed=654321)
# Helper to construct an LSTM cell graph.
@classmethod
def LSTMCell(cls, x, mprev, cprev, weights):
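# Concatenates x with the previous memory, applies a single matmul to get
# the pre-activations for the cell input (i_i) and the input/forget/output
# gates (i_g, f_g, o_g), then performs the standard LSTM update:
#   new_c = sigmoid(f_g) * cprev + sigmoid(i_g) * tanh(i_i), clipped to [-50, 50]
#   new_m = sigmoid(o_g) * tanh(new_c)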
xm = array_ops.concat([x, mprev], 1)
i_i, i_g, f_g, o_g = array_ops.split(
value=math_ops.matmul(xm, weights), num_or_size_splits=4, axis=1)
new_c = math_ops.sigmoid(f_g) * cprev + math_ops.sigmoid(
i_g) * math_ops.tanh(i_i)
new_c = math_ops.maximum(math_ops.minimum(new_c, 50.0), -50.0)
new_m = math_ops.sigmoid(o_g) * math_ops.tanh(new_c)
return new_m, new_c
def _BuildForward(self, weights, inp, mode="cell"):
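# mode selects how much of the unrolled LSTM is wrapped in functions:
# "complete" builds a plain graph, "cell" wraps a single LSTM step,
# "loop" wraps the whole unrolled loop, and "loop10" wraps groups of
# ten steps plus an outer loop function that calls them.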
def Loop(cell, w, i):
x = array_ops.unstack(i, self.NUM_UNROLL)
m = array_ops.zeros_like(x[0])
c = array_ops.zeros_like(x[0])
for i in range(self.NUM_UNROLL):
m, c = cell(x[i], m, c, w)
return m
cell = UnrollLSTMTest.LSTMCell
if mode == "complete":
# Constructs the complete graph in python.
return Loop(cell, weights, inp)
cell = function.Defun(dtypes.float32, dtypes.float32, dtypes.float32,
dtypes.float32)(
cell)
if mode == "cell":
# Just represent the LSTM as a function.
return Loop(cell, weights, inp)
if mode == "loop":
# Wraps the whole loop as a function.
@function.Defun(dtypes.float32, dtypes.float32)
def LSTMLoop(w, i):
return Loop(cell, w, i)
return LSTMLoop(weights, inp)
if mode == "loop10":
# Wraps 10 LSTM steps into one function, and the whole loop into another
# function that calls the former.
# Groups 10 steps at a time.
@function.Defun(dtypes.float32, dtypes.float32, dtypes.float32,
*([dtypes.float32] * 10))
def Loop10(w, m, c, *args):
for x in args:
m, c = cell(x, m, c, w)
return m, c
@function.Defun(dtypes.float32, dtypes.float32)
def LSTMLoop10(weights, inp):
x = array_ops.unstack(inp, self.NUM_UNROLL)
m = array_ops.zeros_like(x[0])
c = array_ops.zeros_like(x[0])
assert self.NUM_UNROLL % 10 == 0
for i in range(0, self.NUM_UNROLL, 10):
m, c = Loop10(weights, m, c, *x[i:i + 10])
return m
return LSTMLoop10(weights, inp)
def testUnrollLSTM(self):
# Run one step of the unrolled lstm graph.
def RunForward(mode, cfg=None):
tf_logging.info("mode = %s", mode)
g = ops.Graph()
start = time.time()
with g.as_default():
weights = self._Weights()
inp = self._Input()
m = self._BuildForward(weights, inp, mode)
gdef = g.as_graph_def()
finish = time.time()
tf_logging.info("time: %f txt size: %d gdef bin size: %d", finish - start,
len(str(gdef)), len(gdef.SerializeToString()))
with g.as_default(), session.Session(config=cfg) as sess:
return self.evaluate(m)
mv0 = RunForward("complete")
for cfg in _OptimizerOptions():
tf_logging.info("cfg = %s", cfg)
mv1 = RunForward("cell", cfg)
mv2 = RunForward("loop", cfg)
mv3 = RunForward("loop10", cfg)
self.assertAllClose(mv0, mv1, rtol=1e-4)
self.assertAllClose(mv0, mv2, rtol=1e-4)
self.assertAllClose(mv0, mv3, rtol=1e-4)
def testUnrollLSTMGrad(self):
# Run one step of the unrolled lstm graph.
def RunForwardBackward(mode, cfg=None):
tf_logging.info("mode = %s", mode)
g = ops.Graph()
start = time.time()
with g.as_default():
weights = self._Weights()
inp = self._Input()
m = self._BuildForward(weights, inp, mode)
loss = math_ops.reduce_sum(math_ops.square(m))
dw = gradients_impl.gradients([loss], [weights])
gdef = g.as_graph_def()
finish = time.time()
tf_logging.info("time: %f txt size: %d gdef bin size: %d", finish - start,
len(str(gdef)), len(gdef.SerializeToString()))
with g.as_default(), session.Session(config=cfg) as sess:
return self.evaluate(dw)
d0 = RunForwardBackward("complete")
for cfg in _OptimizerOptions():
tf_logging.info("cfg = %s", cfg)
d1 = RunForwardBackward("cell", cfg)
d2 = RunForwardBackward("loop", cfg)
d3 = RunForwardBackward("loop10", cfg)
self.assertAllClose(d0, d1, rtol=1e-4, atol=1e-4)
self.assertAllClose(d0, d2, rtol=1e-4, atol=1e-4)
self.assertAllClose(d0, d3, rtol=1e-4, atol=1e-4)
class FunctionInlineControlTest(test.TestCase):
@test_util.disable_xla("XLA changes the names, breaking graph analysis")
def testFoo(self):
dtype = dtypes.float32
cfg = config_pb2.ConfigProto(
graph_options=config_pb2.GraphOptions(
optimizer_options=config_pb2.OptimizerOptions(
opt_level=config_pb2.OptimizerOptions.L0,
do_common_subexpression_elimination=True,
do_function_inlining=True,
do_constant_folding=True)))
cell_func_call_pattern = re.compile(r"Cell[^/]*\(")
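# When Cell is inlined, its function-call node disappears from the runtime
# trace, so this pattern should match a timeline label only when
# noinline=True (verified via MetadataHasCell below).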
for noinline in [False, True]:
@function.Defun(dtype, noinline=noinline)
def Cell(v):
# If v is a vector [n, 1], x is a big square matrix.
x = math_ops.tanh(v + array_ops.transpose(v, [1, 0]))
return math_ops.reduce_sum(x, 1, keepdims=True)
@function.Defun(dtype)
def Forward(x):
for _ in range(10):
# pylint: disable=cell-var-from-loop
x = Cell(x)
return math_ops.reduce_sum(x, [0, 1])
self.assertEqual(noinline, Cell.definition.attr["_noinline"].b)
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtype)
y = Forward(x)
dx, = gradients_impl.gradients([y], [x])
np.random.seed(321)
inp = np.random.uniform(-1, 1, [16, 1]).astype(np.float32)
run_metadata = config_pb2.RunMetadata()
with session.Session(graph=g, config=cfg) as sess:
ans = sess.run(
[y, dx], {x: inp},
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
print(ans[0], np.sum(ans[1]))
self.assertAllClose(ans[0], 255.971, rtol=1e-3)
self.assertAllClose(np.sum(ans[1]), 13.0408, rtol=1e-3)
def MetadataHasCell(run_metadata):
for dev_stats in run_metadata.step_stats.dev_stats:
for node_stats in dev_stats.node_stats:
if cell_func_call_pattern.search(node_stats.timeline_label):
return True
return False
self.assertEqual(MetadataHasCell(run_metadata), noinline)
class ModuleFunctionTest(test.TestCase):
@test_util.run_deprecated_v1
def testBasic(self):
@function.Defun(*[dtypes.float32] * 3)
def LinearWithCApi(w, b, x):
return nn_ops.relu(math_ops.matmul(x, w) + b)
@function.Defun(*[dtypes.float32] * 5)
def Linear2WithCApi(w1, b1, w2, b2, x):
return LinearWithCApi(w2, b2, LinearWithCApi(w1, b1, x))
with ops.Graph().as_default():
a, b, c, d, e = [
constant_op.constant([[_]], dtype=dtypes.float32) for _ in range(5)
]
y = LinearWithCApi(a, b, c)
z = Linear2WithCApi(a, b, c, d, e)
with session.Session() as sess:
self.assertAllEqual([[1]], self.evaluate(y))
self.assertAllEqual([[5]], self.evaluate(z))
class VariableHoistingTest(test.TestCase):
def _testSimpleModel(self, use_forward_func, use_resource=False):
def _Model(x):
w = variable_scope.get_variable(
"w", (64, 64),
initializer=init_ops.random_uniform_initializer(seed=312),
use_resource=use_resource)
b = variable_scope.get_variable(
"b", (64),
initializer=init_ops.zeros_initializer(),
use_resource=use_resource)
return math_ops.sigmoid(math_ops.matmul(x, w) + b)
@function.Defun()
def Model(x):
return _Model(x)
cvars = []
@function.Defun()
def Grad(x, y0):
if use_forward_func:
y = Model(x)
else:
y = _Model(x)
loss = math_ops.reduce_mean(
math_ops.reduce_sum(y0 * math_ops.log(y), 1), 0)
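# w and b are created inside the model and hoisted out of the function;
# they surface here as extra captured arguments and via get_extra_vars().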
arg_w, arg_b = function.get_extra_args()
self.assertEqual(arg_w.get_shape(), tensor_shape.TensorShape([64, 64]))
self.assertEqual(arg_b.get_shape(), tensor_shape.TensorShape([64]))
dw, db = gradients_impl.gradients(loss, [arg_w, arg_b])
cvars.extend(function.get_extra_vars())
return loss, dw, db
g = ops.Graph()
with g.as_default():
x = random_ops.random_normal([64, 64], seed=100)
y0 = random_ops.random_normal([64, 64], seed=200)
with variable_scope.variable_scope("Foo"):
loss, dw, db = Grad(x, y0)
self.assertEqual(2, len(cvars))
w, b = cvars[:2]
self.assertEqual("Foo/w", w.op.name)
self.assertEqual("Foo/b", b.op.name)
with self.session(graph=g) as sess:
self.evaluate(variables.global_variables_initializer())
w, b, x, y0, loss, dw, db = self.evaluate([w, b, x, y0, loss, dw, db])
self.assertAllEqual(w.shape, (64, 64))
self.assertAllClose(np.sum(w), 2050.44)
self.assertAllEqual(b.shape, (64,))
self.assertAllClose(np.sum(b), 0.0)
self.assertAllClose(loss, -2.27, rtol=1e-2)
self.assertAllEqual(dw.shape, (64, 64))
self.assertAllClose(np.sum(dw), -1.04, rtol=1e-2)
self.assertAllEqual(db.shape, (64,))
self.assertAllClose(np.sum(db), 0.509, rtol=1e-2)
@test_util.run_deprecated_v1
def testBasic(self):
self._testSimpleModel(False)
self._testSimpleModel(True)
@test_util.run_deprecated_v1
def testBasicResource(self):
self._testSimpleModel(False, use_resource=True)
self._testSimpleModel(True, use_resource=True)
class TemplateTest(test.TestCase):
@test_util.run_v1_only("make_template not supported in TF2")
def testBasic(self):
self.assertTemplateVariableSharing(use_resource=True, defun_first=False)
@test_util.run_v1_only("make_template not supported in TF2")
def testBasicRef(self):
self.assertTemplateVariableSharing(use_resource=False, defun_first=False)
@test_util.run_v1_only("make_template not supported in TF2")
def testBasicDefunFirst(self):
self.assertTemplateVariableSharing(use_resource=True, defun_first=True)
@test_util.run_v1_only("make_template not supported in TF2")
def testBasicRefDefunFirst(self):
self.assertTemplateVariableSharing(use_resource=False, defun_first=True)
def assertTemplateVariableSharing(self, use_resource, defun_first):
parameters = []
def MakeModel(x):
w = variable_scope.get_variable(
"w", (64, 64),
initializer=init_ops.random_uniform_initializer(seed=312),
use_resource=use_resource)
b = variable_scope.get_variable(
"b", (64),
initializer=init_ops.zeros_initializer(),
use_resource=use_resource)
parameters.extend((w, b))
return math_ops.sigmoid(math_ops.matmul(x, w) + b)
model = template.make_template("f", MakeModel, create_scope_now_=True)
@function.Defun()
def ModelDefun(x):
return model(x)
x = array_ops.placeholder(dtypes.float32)
if defun_first:
ModelDefun(x)
model(x)
else:
model(x)
ModelDefun(x)
w1, b1, w2, b2 = parameters # pylint: disable=unbalanced-tuple-unpacking
self.assertIs(w1, w2)
self.assertIs(b1, b2)
class DevicePlacementTest(test.TestCase):
def testNoDeviceGraph(self):
with ops.Graph().as_default():
@function.Defun(*[dtypes.float32] * 2)
def Matmul(a, b):
return math_ops.matmul(a, b)
Matmul(1., 2.)
gdef = ops.get_default_graph().as_graph_def()
self.assertAllEqual(len(gdef.library.function), 1)
fdef = gdef.library.function[0]
for node in fdef.node_def:
self.assertAllEqual(node.device, "")
def testNestedDevices(self):
with ops.Graph().as_default(), ops.device("CPU:0"):
@function.Defun(*[dtypes.float32] * 2)
def Matmul(a, b):
return math_ops.matmul(a, b)
with ops.device("CPU:1"):
@function.Defun(*[dtypes.float32] * 2)
def Divide(a, b):
return math_ops.divide(a, b)
Divide(Matmul(1., 2.), 3.)
gdef = ops.get_default_graph().as_graph_def()
matmul_fdef = [
f for f in gdef.library.function if "Matmul" in f.signature.name
]
divide_fdef = [
f for f in gdef.library.function if "Divide" in f.signature.name
]
self.assertAllEqual(len(matmul_fdef), 1)
self.assertAllEqual(len(divide_fdef), 1)
for node in matmul_fdef[0].node_def:
self.assertAllEqual(node.device, "/device:CPU:0")
for node in divide_fdef[0].node_def:
self.assertAllEqual(node.device, "/device:CPU:1")
def _testNestedDeviceWithSameFunction(self, func_name):
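# Instantiates the same Python function under two different device scopes.
# With func_name=None each instantiation gets a unique generated name; with
# an explicit func_name the second registration conflicts with the first
# (exercised by testFunctionWithName below).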
def MatmulWrap(a, b):
@function.Defun(
func_name=func_name, *[dtypes.int32] * 2)
def Matmul(a, b):
return math_ops.matmul(a, b)
return Matmul(a, b)
with ops.Graph().as_default(), ops.device("CPU:0"):
c = MatmulWrap(1, 2)
with ops.device("CPU:1"):
MatmulWrap(c, 3)
gdef = ops.get_default_graph().as_graph_def()
devices = []
for node in gdef.library.function[0].node_def:
devices.append(node.device)
for node in gdef.library.function[1].node_def:
devices.append(node.device)
self.assertAllEqual(sorted(devices), ["/device:CPU:0", "/device:CPU:1"])
def testFunctionWithName(self):
with self.assertRaises(InvalidArgumentError) as cm:
self._testNestedDeviceWithSameFunction("MatmulTest")
self.assertEqual(
cm.exception.message,
"Cannot add function \'MatmulTest\' because a different "
"function with the same name already exists.")
def testFunctionWithoutName(self):
self._testNestedDeviceWithSameFunction(None)
if __name__ == "__main__":
test.main()