# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for JIT compilation on the CPU and GPU devices."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
from tensorflow.contrib.compiler import jit
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test
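# experimental_jit_scope marks the ops created inside it with the attributes
# the XLA clustering pass looks for; the tests below use it to request
# compilation of specific subgraphs.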
jit_scope = jit.experimental_jit_scope
# Disable rewrites to make sure we don't end up having to update this test
# whenever we implement new ones.
def NoRewriteSessionConfig():
rewriter_config = rewriter_config_pb2.RewriterConfig(
disable_model_pruning=True,
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
function_optimization=rewriter_config_pb2.RewriterConfig.OFF)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
return config_pb2.ConfigProto(graph_options=graph_options)
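# CompiledKernel wraps 'fn' in a Defun marked compiled=True so that the body
# can be compiled by XLA as a single kernel (which is what MetadataHasXlaOp
# checks for below). A hypothetical use:
#   CompiledKernel(lambda a, b: a + b, t1, t2)
# returns the output tensors of calling the compiled function on t1 and t2.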
def CompiledKernel(fn, *inputs, **kwargs):
"""Execute 'fn' as a compiled XLA kernel, with 'inputs'."""
name = kwargs.pop("name", None)
noinline = kwargs.pop("noinline", None)
@function.Defun(func_name=name, noinline=noinline, compiled=True)
def Compiled(*args):
return fn(*args)
return Compiled(*inputs)
def RunMetadataLabels(run_metadata):
"""Returns all labels in run_metadata."""
labels = []
for dev_stats in run_metadata.step_stats.dev_stats:
for node_stats in dev_stats.node_stats:
labels.append(node_stats.timeline_label)
return labels
def InLabels(labels, substr):
"""Returns true iff one of the labels contains substr."""
  return any(substr in x for x in labels)
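# Note: timeline_label strings look roughly like "name = OpType(inputs)", so a
# plain substring check such as "XlaRun" (below) or "XlaCompile(" is enough to
# spot a particular kernel in the trace.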
def MetadataHasXlaOp(run_metadata):
"""Returns true if there are XlaRun kernels in run_metadata's timeline."""
# TODO(phawkins): find a less hacky way to test whether a kernel ran.
return InLabels(RunMetadataLabels(run_metadata), "XlaRun")
class JitLaunchTest(test.TestCase):
# Evaluates 'fn' on 'args' both directly and as a compiled XLA kernel.
  # Verifies that the outputs match and that XLA was invoked. 'fn' must take
  # as many tensor arguments as there are entries in 'args', and must return a
  # tuple of output tensors.
#
# If 'require_kernel_launch' is True, then we verify that an XlaCompile/XlaRun
# node actually ran. However, it is sometimes possible for XlaCompile/XlaRun
# ops to be constant-folded away, so the check is optional.
def _compare(self, fn, args, require_kernel_launch=True, noinline=None):
with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
placeholders = []
feeds = {}
for arg in args:
placeholder = array_ops.placeholder(
dtypes.as_dtype(arg.dtype), list(arg.shape))
placeholders.append(placeholder)
feeds[placeholder] = arg
compiled_op = CompiledKernel(fn, *placeholders, noinline=noinline)
direct_op = fn(*placeholders)
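      # Run the compiled version with a full trace so the timeline records any
      # XlaCompile/XlaRun kernels; MetadataHasXlaOp inspects it below.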
run_metadata = config_pb2.RunMetadata()
compiled = sess.run(compiled_op,
feeds,
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
print("Compiled Result {}".format(compiled))
if require_kernel_launch:
        self.assertTrue(MetadataHasXlaOp(run_metadata))
direct = sess.run(direct_op, feeds)
print("Direct Result {}".format(direct))
if (isinstance(compiled, (tuple, list)) and
(isinstance(direct, (tuple, list)))):
for (x, y) in zip(compiled, direct):
self.assertAllClose(x, y, rtol=1e-1)
else:
self.assertAllClose(compiled, direct, rtol=1e-2)
def testNoOutputs(self):
with session_lib.Session() as sess:
# Check that calling the result as a compiled kernel doesn't crash.
@function.Defun(compiled=True)
def KernelWithNoOutputs():
a = constant_op.constant(100) # pylint: disable=unused-variable
call = KernelWithNoOutputs() # pylint: disable=assignment-from-no-return
sess.run(call, {})
def testAliasing(self):
"""Regression test for compiled functions that return an aliased buffer.
XLA returns aliased buffers if outputs are identical. Tests that
we handle that case.
"""
def AddOnceReturnTwice(x):
y = math_ops.add(x, x)
return y, y
# Exercises compiling a function (say, Foo) which calls another function
# (say, Bar) which is not inlined. When the compiler compiles Foo, it needs
# to symbolically execute Bar correctly regardless of whether Bar is inlined
# or not.
# TODO(b/36139787): Re-enable this test when noinline works again.
# Tests compiled=True and noinline=True.
# self._compare(
# AddOnceReturnTwice, [np.array(
# [[[0.5, -1.0]]], dtype=np.float32)],
# noinline=True)
# Tests compiled=True and noinline=False.
self._compare(
AddOnceReturnTwice, [np.array(
[[[0.5, -1.0]]], dtype=np.float32)],
noinline=False)
def testOneConstOutput(self):
"""Test consisting of a single constant return value."""
def OneConstOutput():
return constant_op.constant([-3, 44, 99])
self._compare(OneConstOutput, [], require_kernel_launch=False)
def testConstZeroElementOutput(self):
"""Test consisting of a constant zero element return value."""
def ConstZeroElementOutput():
return array_ops.fill([7, 0], 3.0)
self._compare(ConstZeroElementOutput, [], require_kernel_launch=False)
def testSomeConstOutputs(self):
"""Test kernels that return a mixture of const and non-const outputs."""
def SomeConstOutputs(x):
return constant_op.constant(
[-2, 7]), array_ops.identity(x), constant_op.constant(3.5)
self._compare(
SomeConstOutputs, [np.array(
[[1, 2, 3], [4, 5, 6]], dtype=np.float32)])
def testInt32Input(self):
"""Test an int32-typed input.
On a GPU, int32 tensors will be placed in host memory.
"""
def AddToSelf(x):
return math_ops.add(x, x)
self._compare(AddToSelf, [np.array([7, 1, 3], dtype=np.int32)])
def testMandatoryConstantInput(self):
"""Tests an operator that has a mandatory-constant shape input."""
def FillWithFloat(x):
return array_ops.fill(x, 9.5)
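    # XLA needs Fill's dims input at compile time, so the launch machinery has
    # to treat the fed value as a compile-time constant for each compilation.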
self._compare(FillWithFloat, [np.array([3, 2], dtype=np.int32)])
def testMnistForwardFunc(self):
"""Compute inference function from MNIST beginners tutorial."""
batch_size = 16
image_size = 28 * 28
num_classes = 10
# Define a TensorFlow function to compute the forward pass.
def MnistForward(w, b, x):
return nn_ops.softmax(math_ops.matmul(x, w) + b)
w = np.random.random_sample((image_size, num_classes)).astype(np.float32)
b = np.random.random_sample((num_classes)).astype(np.float32)
x = np.random.random_sample((batch_size, image_size)).astype(np.float32)
self._compare(MnistForward, [w, b, x])
def testExplicitMarking(self):
"""Test explicit marking of operators to compile."""
batch_size = 16
image_size = 28 * 28
num_classes = 10
with ops.Graph().as_default():
x = array_ops.placeholder(dtypes.float32)
w = array_ops.placeholder(dtypes.float32)
b = array_ops.placeholder(dtypes.float32)
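      # The nested jit_scope below gives the square its own scope, so it is
      # expected to end up in a separate XLA cluster from the matmul/add pair
      # (hence the TODO about two kernel launches further down).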
with jit_scope():
y1 = math_ops.matmul(x, w)
y2 = math_ops.add(y1, b)
with jit_scope():
y = math_ops.square(y2)
dw = np.random.random_sample((image_size, num_classes)).astype(np.float32)
db = np.random.random_sample((num_classes)).astype(np.float32)
dx = np.random.random_sample((batch_size, image_size)).astype(np.float32)
with session_lib.Session() as sess:
run_metadata = config_pb2.RunMetadata()
output = sess.run(y, {x: dx,
w: dw,
b: db},
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
# TODO(phawkins): really we would like to test that there were exactly
# two kernel launches. However, we have no reliable way to determine
# that.
        self.assertTrue(MetadataHasXlaOp(run_metadata))
expected = np.square(np.dot(dx, dw) + db)
self.assertAllClose(expected, output, rtol=1e-1)
class XlaCompilationTest(test.TestCase):
"""Tests for auto-compilation on CPU/GPU devices."""
def testReshape(self):
"""Tests an operator with compile-time constant and non-constant inputs."""
with self.test_session(config=NoRewriteSessionConfig()) as sess:
x = array_ops.placeholder(dtypes.float32)
y = array_ops.placeholder(dtypes.int32)
with jit_scope():
# Reshape's first argument is non-constant in the JIT, but its second
# (shape) argument will be treated as a compile-time constant for
# each JIT compilation.
        # We do not use a tf.constant() argument since we want to ensure the
# shape is still a run-time argument to the JIT, and not
# statically known as part of the JIT compilation's input graph.
z = array_ops.reshape(x, y)
run_metadata = config_pb2.RunMetadata()
out = sess.run(z,
{x: np.array([1, 2, 3, 4, 5, 6], np.float32),
y: [-1, 3]},
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(MetadataHasXlaOp(run_metadata))
self.assertAllClose(np.array([[1, 2, 3], [4, 5, 6]], np.float32), out)
def testIgnoredArguments(self):
"""Tests that JIT computations can ignore formal parameters."""
with self.test_session(config=NoRewriteSessionConfig()) as sess:
x = array_ops.placeholder(dtypes.int32)
y = array_ops.placeholder(dtypes.int32)
with jit_scope():
z = math_ops.add(x, x)
w = math_ops.add(y, y)
# Pulls 'w' into the same compilation via control dependencies.
with ops.control_dependencies([w]):
n = control_flow_ops.no_op()
with ops.control_dependencies([n]):
t = math_ops.add(z, z)
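      # With x = 7: z = 14 and t = z + z = 28; y only feeds the control
      # dependency and never reaches the output.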
run_metadata = config_pb2.RunMetadata()
out = sess.run(t, {x: np.int32(7),
y: np.int32(404)},
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(MetadataHasXlaOp(run_metadata))
self.assertAllClose(28, out)
def testLoops(self):
"""Tests that compilation accepts computations containing loops."""
with self.test_session(config=NoRewriteSessionConfig()) as session:
x = array_ops.placeholder(dtypes.float32)
with jit_scope():
c = lambda i, _: math_ops.less(i, 5)
b = lambda i, x: (i + 1, x * 2.0 + 1.0)
_, y = control_flow_ops.while_loop(c, b, (constant_op.constant(0), x))
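      # Starting from x = 2, five iterations of x -> 2 * x + 1 produce
      # 5, 11, 23, 47 and finally 95.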
run_metadata = config_pb2.RunMetadata()
result = session.run(y, {x: np.float32(2)},
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(MetadataHasXlaOp(run_metadata))
self.assertAllClose(result, np.float32(95), rtol=1e-1)
def testCond(self):
"""Tests that compilation handles switch operators."""
with self.test_session(config=NoRewriteSessionConfig()) as session:
x = array_ops.placeholder(dtypes.float32)
y = array_ops.placeholder(dtypes.float32)
c = array_ops.placeholder(dtypes.bool)
with jit_scope():
z = x + 1.0
w = control_flow_ops.cond(c, lambda: z, lambda: y)
t = math_ops.add(z, w)
# If JIT compilation chooses to cluster z and t, then execution will
# deadlock.
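      # With c = True the cond returns the z branch, so w = z = 3 and
      # t = z + w = 6.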
run_metadata = config_pb2.RunMetadata()
result = session.run(t, {x: np.float32(2),
y: np.float32(4),
c: True},
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(MetadataHasXlaOp(run_metadata))
self.assertAllClose(result, np.float32(6), rtol=1e-1)
def testNestedFunction(self):
g = ops.Graph()
with g.as_default():
@function.Defun(compiled=True)
def Bar(x, y):
return x + 2 * y
@function.Defun(compiled=True)
def Foo(x):
return Bar(x * x, x * x * x)
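      # Foo(x) = x * x + 2 * x**3, so Foo(2) = 20 and Foo(10) = 2100, matching
      # the assertion below.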
@function.Defun()
def Entry(x):
return Foo(x)
inp = array_ops.placeholder(dtypes.float32)
out = Entry(inp)
with self.test_session(
config=NoRewriteSessionConfig(), graph=g, use_gpu=True) as sess:
run_metadata = config_pb2.RunMetadata()
val = sess.run(out,
feed_dict={inp: [2., 10.]},
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
self.assertAllClose(val, [20., 2100.])
def testLoopDeadlock(self):
"""Regression test for bug that caused deadlocks in graphs with loops."""
with self.test_session(config=NoRewriteSessionConfig()) as session:
x = array_ops.placeholder(dtypes.float32)
with jit_scope():
y = x + 1.0
c = lambda i, _x, _y: math_ops.less(i, 5)
b = lambda i, x, _y: (i + 1, x * 2.0 + 1.0, x - 3.0)
_, _, w = control_flow_ops.while_loop(c, b,
(constant_op.constant(0), y, x))
u = w + y
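      # With x = 2: y = 3 and the loop variables start at (0, 3, 2); after five
      # iterations the third variable reaches 60, so u = 60 + y = 63.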
result = session.run(u, {x: np.float32(2)})
self.assertAllClose(result, np.float32(63), rtol=1e-1)
def testGradient(self):
"""Tests that the backprop function is properly compiled."""
def _Run(compiled):
@function.Defun(compiled=compiled)
def Forward(x):
return math_ops.log(x)
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtypes.float32)
y = Forward(x)
dx, = gradients_impl.gradients(y, [x], 1.0)
cfg = NoRewriteSessionConfig()
cfg.graph_options.optimizer_options.opt_level = (
config_pb2.OptimizerOptions.L1)
cfg.graph_options.optimizer_options.do_function_inlining = True
with session_lib.Session(graph=g, config=cfg) as sess:
run_metadata = config_pb2.RunMetadata()
dx_val = sess.run(dx,
feed_dict={x: 100.},
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
self.assertAllClose(dx_val, 0.01)
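        # d/dx log(x) = 1/x, so at x = 100 with dy = 1.0 the gradient is 0.01.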
return RunMetadataLabels(run_metadata)
# SymGrad[f=log(x)](x, dy) = 1/x * dy
#
# Note: we don't need to compute log(x) for dx due to graph pruning.
# Do not compile the backprop. We should see one Reciprocal and one Mul.
labels = _Run(compiled=False)
self.assertFalse(InLabels(labels, "Log"))
self.assertTrue(InLabels(labels, "Reciprocal"))
self.assertTrue(InLabels(labels, "Mul"))
self.assertFalse(InLabels(labels, "XlaCompile"))
self.assertFalse(InLabels(labels, "XlaRun"))
# Compile the backprop. One XlaCompile/XlaRun pair.
labels = _Run(compiled=True)
self.assertFalse(InLabels(labels, "Log"))
self.assertFalse(InLabels(labels, "Reciprocal"))
self.assertFalse(InLabels(labels, "Mul"))
self.assertTrue(InLabels(labels, "XlaCompile"))
self.assertTrue(InLabels(labels, "XlaRun"))
class ElementWiseFusionTest(test.TestCase):
  # Runs a simple graph at the given global_jit_level; the fusion_only behavior
  # is controlled by the TF_XLA_FLAGS environment variable set by the caller.
def simpleTest(self, arg0, arg1, global_jit_level):
config = config_pb2.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = global_jit_level
with session_lib.Session(config=config) as sess:
a1 = array_ops.placeholder(dtypes.float32, [2, 2], name="a1")
a2 = array_ops.placeholder(dtypes.float32, [2, 2], name="a2")
# Two element-wise ops. We need at least two ops since single
# element clusters are not passed to XLA in fusion_only mode.
a3 = a1 * a2
a4 = a3 + a1
# A matmul to break XLA clustering.
a5 = math_ops.matmul(a4, a1)
# Two more element-wise ops.
a6 = a5 - a4
a7 = a6 + a2
run_metadata = config_pb2.RunMetadata()
output = sess.run(
a7, {
a1: arg0,
a2: arg1
},
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
labels = RunMetadataLabels(run_metadata)
xla_compile_count = sum("XlaCompile(" in x for x in labels)
xla_run_count = sum("XlaRun(" in x for x in labels)
self.assertEqual(xla_compile_count, xla_run_count)
return output, xla_run_count
def testElementWiseClustering(self):
arg0 = np.random.rand(2, 2).astype(np.float32)
arg1 = np.random.rand(2, 2).astype(np.float32)
os.environ["TF_XLA_FLAGS"] = (
"--tf_xla_fusion_only=true "
"--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", ""))
tf_op, tf_count = self.simpleTest(arg0, arg1,
config_pb2.OptimizerOptions.OFF)
self.assertEqual(0, tf_count)
tfef_op, tfef_count = self.simpleTest(arg0, arg1,
config_pb2.OptimizerOptions.ON_1)
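    # With fusion_only and the CPU global JIT enabled, the element-wise groups
    # on either side of the matmul are each expected to form their own cluster,
    # which is why two XlaRun launches are asserted below.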
self.assertEqual(2, tfef_count)
self.assertAllClose(tf_op, tfef_op, rtol=1e-1)
if __name__ == "__main__":
test.main()