blob: 91f21cb1f36cca174061bbd80593100b5f58bc5c [file] [log] [blame]
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for debugger functionalities in tf.Session with grpc:// URLs.

This test file focuses on the grpc:// debugging of local (non-distributed)
tf.Sessions.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import shutil
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.lib import debug_utils
from tensorflow.python.debug.lib import grpc_debug_test_server
from tensorflow.python.debug.lib import session_debug_testlib
from tensorflow.python.debug.wrappers import framework
from tensorflow.python.debug.wrappers import grpc_wrapper
from tensorflow.python.debug.wrappers import hooks
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.training import monitored_session
class GrpcDebugServerTest(test_util.TensorFlowTestCase):
  """Tests the run/stop life cycle of the gRPC debug test server."""

  def testRepeatedRunServerRaisesException(self):
    """Calling run_server() on an already-running server raises ValueError."""
    # The server is started asynchronously. It needs to be polled till its state
    # has become started.
    (_, _, _, server_thread,
     server) = grpc_debug_test_server.start_server_on_separate_thread(
         poll_server=True)

    try:
      with self.assertRaisesRegexp(
          ValueError, "Server has already started running"):
        server.run_server()
    finally:
      # Shut the server down even when the assertion above fails, so that a
      # failing test does not leave the server thread running.
      server.stop_server().wait()
      server_thread.join()

  def testRepeatedStopServerRaisesException(self):
    """Calling stop_server() on an already-stopped server raises ValueError."""
    (_, _, _, server_thread,
     server) = grpc_debug_test_server.start_server_on_separate_thread(
         poll_server=True)
    server.stop_server().wait()
    server_thread.join()

    with self.assertRaisesRegexp(ValueError, "Server has already stopped"):
      server.stop_server().wait()

  def testRunServerAfterStopRaisesException(self):
    """Calling run_server() on a stopped server raises ValueError."""
    (_, _, _, server_thread,
     server) = grpc_debug_test_server.start_server_on_separate_thread(
         poll_server=True)
    server.stop_server().wait()
    server_thread.join()

    with self.assertRaisesRegexp(ValueError, "Server has already stopped"):
      server.run_server()

  def testStartServerWithoutBlocking(self):
    """With blocking=False the starter thread can be joined before stopping."""
    (_, _, _, server_thread,
     server) = grpc_debug_test_server.start_server_on_separate_thread(
         poll_server=True, blocking=False)
    # The thread that starts the server shouldn't block, so we should be able to
    # join it before stopping the server.
    server_thread.join()
    server.stop_server().wait()
class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
  """Session debugger tests that use a local gRPC test server as debug URL."""

  @classmethod
  def setUpClass(cls):
    """Starts the gRPC debug test server shared by the tests of this class."""
    session_debug_testlib.SessionDebugTestBase.setUpClass()
    (cls._server_port, cls._debug_server_url, cls._server_dump_dir,
     cls._server_thread,
     cls._server) = grpc_debug_test_server.start_server_on_separate_thread()

  @classmethod
  def tearDownClass(cls):
    """Stops the shared test server and joins its thread."""
    # Stop the test server and join the thread.
    cls._server.stop_server().wait()
    cls._server_thread.join()

    session_debug_testlib.SessionDebugTestBase.tearDownClass()

  def setUp(self):
    """Points the dump root at the directory the test server dumps into."""
    # Override the dump root as the test server's dump directory.
    self._dump_root = self._server_dump_dir

  def tearDown(self):
    """Removes the server's dump directory, then runs the base tearDown."""
    if os.path.isdir(self._server_dump_dir):
      shutil.rmtree(self._server_dump_dir)
    session_debug_testlib.SessionDebugTestBase.tearDown(self)

  def _debug_urls(self, run_number=None):
    """Returns the grpc:// URL of the test server; `run_number` is ignored."""
    return ["grpc://localhost:%d" % self._server_port]

  def _debug_dump_dir(self, run_number=None):
    """Returns the dump root, or its "run_<N>" subdirectory if N is given."""
    if run_number is None:
      return self._dump_root
    else:
      return os.path.join(self._dump_root, "run_%d" % run_number)

  def testConstructGrpcDebugWrapperSessionWithInvalidTypeRaisesException(self):
    """A non-str, non-list server address raises TypeError."""
    sess = session.Session(
        config=session_debug_testlib.no_rewrite_session_config())
    with self.assertRaisesRegexp(
        TypeError, "Expected type str or list in grpc_debug_server_addresses"):
      grpc_wrapper.GrpcDebugWrapperSession(sess, 1337)

  def testConstructGrpcDebugWrapperSessionWithInvalidTypeRaisesException2(self):
    """A list of server addresses containing a non-str raises TypeError."""
    sess = session.Session(
        config=session_debug_testlib.no_rewrite_session_config())
    with self.assertRaisesRegexp(
        TypeError, "Expected type str in list grpc_debug_server_addresses"):
      grpc_wrapper.GrpcDebugWrapperSession(sess, ["localhost:1337", 1338])

  def testUseInvalidWatchFnTypeWithGrpcDebugWrapperSessionRaisesException(self):
    """Passing a string as watch_fn raises TypeError."""
    sess = session.Session(
        config=session_debug_testlib.no_rewrite_session_config())
    with self.assertRaises(TypeError):
      grpc_wrapper.GrpcDebugWrapperSession(
          sess, "localhost:%d" % self._server_port, watch_fn="foo")

  def testGrpcDebugWrapperSessionWithoutWatchFnWorks(self):
    """Without a watch_fn, DebugIdentity dumps of all five tensors arrive."""
    u = variables.Variable(2.1, name="u")
    v = variables.Variable(20.0, name="v")
    w = math_ops.multiply(u, v, name="w")

    sess = session.Session(
        config=session_debug_testlib.no_rewrite_session_config())
    sess.run(u.initializer)
    sess.run(v.initializer)

    sess = grpc_wrapper.GrpcDebugWrapperSession(
        sess, "localhost:%d" % self._server_port)
    w_result = sess.run(w)
    self.assertAllClose(42.0, w_result)

    dump = debug_data.DebugDumpDir(self._dump_root)
    self.assertEqual(5, dump.size)
    self.assertAllClose([2.1], dump.get_tensors("u", 0, "DebugIdentity"))
    self.assertAllClose([2.1], dump.get_tensors("u/read", 0, "DebugIdentity"))
    self.assertAllClose([20.0], dump.get_tensors("v", 0, "DebugIdentity"))
    self.assertAllClose([20.0], dump.get_tensors("v/read", 0, "DebugIdentity"))
    self.assertAllClose([42.0], dump.get_tensors("w", 0, "DebugIdentity"))

  def testGrpcDebugWrapperSessionWithWatchFnWorks(self):
    """A tuple-returning watch_fn restricts the dump to the */read nodes."""

    def watch_fn(feeds, fetch_keys):
      del feeds, fetch_keys
      return ["DebugIdentity", "DebugNumericSummary"], r".*/read", None

    u = variables.Variable(2.1, name="u")
    v = variables.Variable(20.0, name="v")
    w = math_ops.multiply(u, v, name="w")

    sess = session.Session(
        config=session_debug_testlib.no_rewrite_session_config())
    sess.run(u.initializer)
    sess.run(v.initializer)

    sess = grpc_wrapper.GrpcDebugWrapperSession(
        sess, "localhost:%d" % self._server_port, watch_fn=watch_fn)
    w_result = sess.run(w)
    self.assertAllClose(42.0, w_result)

    dump = debug_data.DebugDumpDir(self._dump_root)
    # Two debug ops on each of the two */read nodes.
    self.assertEqual(4, dump.size)
    self.assertAllClose([2.1], dump.get_tensors("u/read", 0, "DebugIdentity"))
    self.assertEqual(
        14, len(dump.get_tensors("u/read", 0, "DebugNumericSummary")[0]))
    self.assertAllClose([20.0], dump.get_tensors("v/read", 0, "DebugIdentity"))
    self.assertEqual(
        14, len(dump.get_tensors("v/read", 0, "DebugNumericSummary")[0]))

  def testGrpcDebugHookWithStatelessWatchFnWorks(self):
    """GrpcDebugHook with a WatchOptions-returning watch_fn streams dumps."""
    # Perform some set up. Specifically, construct a simple TensorFlow graph and
    # create a watch function for certain ops.
    def watch_fn(feeds, fetch_keys):
      del feeds, fetch_keys
      return framework.WatchOptions(
          debug_ops=["DebugIdentity", "DebugNumericSummary"],
          node_name_regex_whitelist=r".*/read",
          op_type_regex_whitelist=None,
          tolerate_debug_op_creation_failures=True)

    u = variables.Variable(2.1, name="u")
    v = variables.Variable(20.0, name="v")
    w = math_ops.multiply(u, v, name="w")

    sess = session.Session(
        config=session_debug_testlib.no_rewrite_session_config())
    sess.run(u.initializer)
    sess.run(v.initializer)

    # Create a hook. One could use this hook with say a tflearn Estimator.
    # However, we use a HookedSession in this test to avoid depending on the
    # internal implementation of Estimators.
    grpc_debug_hook = hooks.GrpcDebugHook(
        ["localhost:%d" % self._server_port], watch_fn=watch_fn)
    sess = monitored_session._HookedSession(sess, [grpc_debug_hook])

    # Run the hooked session. This should stream tensor data to the GRPC
    # endpoints.
    w_result = sess.run(w)

    # Verify that the hook monitored the correct tensors.
    self.assertAllClose(42.0, w_result)
    dump = debug_data.DebugDumpDir(self._dump_root)
    self.assertEqual(4, dump.size)
    self.assertAllClose([2.1], dump.get_tensors("u/read", 0, "DebugIdentity"))
    self.assertEqual(
        14, len(dump.get_tensors("u/read", 0, "DebugNumericSummary")[0]))
    self.assertAllClose([20.0], dump.get_tensors("v/read", 0, "DebugIdentity"))
    self.assertEqual(
        14, len(dump.get_tensors("v/read", 0, "DebugNumericSummary")[0]))

  def testTensorBoardDebugHookWorks(self):
    """TensorBoardDebugHook sends dumps, tracebacks and source code once."""
    u = variables.Variable(2.1, name="u")
    v = variables.Variable(20.0, name="v")
    w = math_ops.multiply(u, v, name="w")

    sess = session.Session(
        config=session_debug_testlib.no_rewrite_session_config())
    sess.run(u.initializer)
    sess.run(v.initializer)

    grpc_debug_hook = hooks.TensorBoardDebugHook(
        ["localhost:%d" % self._server_port])
    sess = monitored_session._HookedSession(sess, [grpc_debug_hook])

    # Activate watch point on a tensor before calling sess.run().
    self._server.request_watch("u/read", 0, "DebugIdentity")
    self.assertAllClose(42.0, sess.run(w))

    dump = debug_data.DebugDumpDir(self._dump_root)
    self.assertAllClose([2.1], dump.get_tensors("u/read", 0, "DebugIdentity"))

    # Check that the server has received the stack trace.
    self.assertTrue(self._server.query_op_traceback("u"))
    self.assertTrue(self._server.query_op_traceback("u/read"))
    self.assertTrue(self._server.query_op_traceback("v"))
    self.assertTrue(self._server.query_op_traceback("v/read"))
    self.assertTrue(self._server.query_op_traceback("w"))

    # Check that the server has received the python file content.
    # Query an arbitrary line to make sure that is the case.
    with open(__file__, "rt") as this_source_file:
      first_line = this_source_file.readline().strip()
    self.assertEqual(
        first_line, self._server.query_source_file_line(__file__, 1))

    self._server.clear_data()
    # Call sess.run() again, and verify that this time the traceback and source
    # code is not sent, because the graph version is not newer.
    self.assertAllClose(42.0, sess.run(w))
    # Query an op that really exists in this graph ("u"); querying a name that
    # was never part of the graph would raise regardless of what was re-sent.
    with self.assertRaises(ValueError):
      self._server.query_op_traceback("u")
    with self.assertRaises(ValueError):
      self._server.query_source_file_line(__file__, 1)

  def testTensorBoardDebugHookDisablingTracebackSourceCodeSendingWorks(self):
    """send_traceback_and_source_code=False suppresses traceback/source."""
    u = variables.Variable(2.1, name="u")
    v = variables.Variable(20.0, name="v")
    w = math_ops.multiply(u, v, name="w")

    sess = session.Session(
        config=session_debug_testlib.no_rewrite_session_config())
    sess.run(variables.global_variables_initializer())

    grpc_debug_hook = hooks.TensorBoardDebugHook(
        ["localhost:%d" % self._server_port],
        send_traceback_and_source_code=False)
    sess = monitored_session._HookedSession(sess, [grpc_debug_hook])

    # Activate watch point on a tensor before calling sess.run().
    self._server.request_watch("u/read", 0, "DebugIdentity")
    self.assertAllClose(42.0, sess.run(w))

    # Check that the server has _not_ received any tracebacks, as a result of
    # the disabling above.
    with self.assertRaisesRegexp(
        ValueError, r"Op .*u/read.* does not exist"):
      self._server.query_op_traceback("u/read")
    with self.assertRaisesRegexp(
        ValueError, r".* has not received any source file"):
      self._server.query_source_file_line(__file__, 1)

  def testConstructGrpcDebugHookWithOrWithouGrpcInUrlWorks(self):
    """GrpcDebugHook accepts addresses with and without a grpc:// prefix."""
    hooks.GrpcDebugHook(["grpc://foo:42424"])
    hooks.GrpcDebugHook(["foo:42424"])
class SessionDebugConcurrentTest(
    session_debug_testlib.DebugConcurrentRunCallsTest):
  """Concurrent-run debugger tests that stream to a gRPC test server."""

  @classmethod
  def setUpClass(cls):
    """Runs the base class-level setup, then launches the test server."""
    session_debug_testlib.SessionDebugTestBase.setUpClass()
    server_info = grpc_debug_test_server.start_server_on_separate_thread()
    (cls._server_port, cls._debug_server_url, cls._server_dump_dir,
     cls._server_thread, cls._server) = server_info

  @classmethod
  def tearDownClass(cls):
    """Shuts the test server down and runs the base class-level teardown."""
    # Stop the test server and join the thread.
    stop_event = cls._server.stop_server()
    stop_event.wait()
    cls._server_thread.join()
    session_debug_testlib.SessionDebugTestBase.tearDownClass()

  def setUp(self):
    """Sets the run count and one "thread<i>" dump directory per run."""
    self._num_concurrent_runs = 3
    self._dump_roots = [
        os.path.join(self._server_dump_dir, "thread%d" % run_index)
        for run_index in range(self._num_concurrent_runs)
    ]

  def tearDown(self):
    """Resets the default graph and deletes the server dump directory."""
    ops.reset_default_graph()
    if not os.path.isdir(self._server_dump_dir):
      return
    shutil.rmtree(self._server_dump_dir)

  def _get_concurrent_debug_urls(self):
    """Returns one debug URL per concurrent run, suffixed "/thread<i>"."""
    return [
        self._debug_server_url + "/thread%d" % run_index
        for run_index in range(self._num_concurrent_runs)
    ]
class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
  """Test server gating of debug ops."""

  @classmethod
  def setUpClass(cls):
    """Starts two test debug servers that do not dump to the filesystem."""
    (cls._server_port_1, cls._debug_server_url_1, _, cls._server_thread_1,
     cls._server_1) = grpc_debug_test_server.start_server_on_separate_thread(
         dump_to_filesystem=False)
    (cls._server_port_2, cls._debug_server_url_2, _, cls._server_thread_2,
     cls._server_2) = grpc_debug_test_server.start_server_on_separate_thread(
         dump_to_filesystem=False)
    # Every (server, thread) pair listed here is stopped in tearDownClass();
    # individual tests may append servers they start themselves.
    cls._servers_and_threads = [(cls._server_1, cls._server_thread_1),
                                (cls._server_2, cls._server_thread_2)]

  @classmethod
  def tearDownClass(cls):
    """Stops every registered server and joins its thread."""
    for server, thread in cls._servers_and_threads:
      server.stop_server().wait()
      thread.join()

  def tearDown(self):
    """Resets the default graph and clears the data of both shared servers."""
    ops.reset_default_graph()
    self._server_1.clear_data()
    self._server_2.clear_data()

  def testToggleEnableTwoDebugWatchesNoCrosstalkBetweenDebugNodes(self):
    """Toggling one debug op on a tensor does not enable the other one."""
    with session.Session(
        config=session_debug_testlib.no_rewrite_session_config()) as sess:
      v_1 = variables.Variable(50.0, name="v_1")
      v_2 = variables.Variable(-50.0, name="v_2")
      delta_1 = constant_op.constant(5.0, name="delta_1")
      delta_2 = constant_op.constant(-5.0, name="delta_2")
      inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
      inc_v_2 = state_ops.assign_add(v_2, delta_2, name="inc_v_2")

      sess.run([v_1.initializer, v_2.initializer])

      run_metadata = config_pb2.RunMetadata()
      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.watch_graph(
          run_options,
          sess.graph,
          debug_ops=["DebugIdentity(gated_grpc=true)",
                     "DebugNumericSummary(gated_grpc=true)"],
          debug_urls=[self._debug_server_url_1])

      for i in xrange(4):
        self._server_1.clear_data()

        if i % 2 == 0:
          self._server_1.request_watch("delta_1", 0, "DebugIdentity")
          self._server_1.request_watch("delta_2", 0, "DebugIdentity")
          self._server_1.request_unwatch("delta_1", 0, "DebugNumericSummary")
          self._server_1.request_unwatch("delta_2", 0, "DebugNumericSummary")
        else:
          self._server_1.request_unwatch("delta_1", 0, "DebugIdentity")
          self._server_1.request_unwatch("delta_2", 0, "DebugIdentity")
          self._server_1.request_watch("delta_1", 0, "DebugNumericSummary")
          self._server_1.request_watch("delta_2", 0, "DebugNumericSummary")

        sess.run([inc_v_1, inc_v_2],
                 options=run_options, run_metadata=run_metadata)

        # Watched debug tensors are:
        #   Run 0: delta_[1,2]:0:DebugIdentity
        #   Run 1: delta_[1,2]:0:DebugNumericSummary
        #   Run 2: delta_[1,2]:0:DebugIdentity
        #   Run 3: delta_[1,2]:0:DebugNumericSummary
        self.assertEqual(2, len(self._server_1.debug_tensor_values))
        if i % 2 == 0:
          self.assertAllClose(
              [5.0],
              self._server_1.debug_tensor_values["delta_1:0:DebugIdentity"])
          self.assertAllClose(
              [-5.0],
              self._server_1.debug_tensor_values["delta_2:0:DebugIdentity"])
        else:
          self.assertAllClose(
              [[1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 5.0, 5.0, 5.0,
                0.0, 1.0, 0.0]],
              self._server_1.debug_tensor_values[
                  "delta_1:0:DebugNumericSummary"])
          self.assertAllClose(
              [[1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, -5.0, -5.0, -5.0,
                0.0, 1.0, 0.0]],
              self._server_1.debug_tensor_values[
                  "delta_2:0:DebugNumericSummary"])

  def testToggleWatchesOnCoreMetadata(self):
    """Watches toggled by the server on core metadata fire on even runs only."""
    (_, debug_server_url, _, server_thread,
     server) = grpc_debug_test_server.start_server_on_separate_thread(
         dump_to_filesystem=False,
         toggle_watch_on_core_metadata=[("toggled_1", 0, "DebugIdentity"),
                                        ("toggled_2", 0, "DebugIdentity")])
    # Register the extra server so that tearDownClass() stops it.
    self._servers_and_threads.append((server, server_thread))

    with session.Session(
        config=session_debug_testlib.no_rewrite_session_config()) as sess:
      v_1 = variables.Variable(50.0, name="v_1")
      v_2 = variables.Variable(-50.0, name="v_2")
      # These two nodes have names that match those in the
      # toggle_watch_on_core_metadata argument used when calling
      # start_server_on_separate_thread().
      toggled_1 = constant_op.constant(5.0, name="toggled_1")
      toggled_2 = constant_op.constant(-5.0, name="toggled_2")
      inc_v_1 = state_ops.assign_add(v_1, toggled_1, name="inc_v_1")
      inc_v_2 = state_ops.assign_add(v_2, toggled_2, name="inc_v_2")

      sess.run([v_1.initializer, v_2.initializer])

      run_metadata = config_pb2.RunMetadata()
      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.watch_graph(
          run_options,
          sess.graph,
          debug_ops=["DebugIdentity(gated_grpc=true)"],
          debug_urls=[debug_server_url])

      for i in xrange(4):
        server.clear_data()

        sess.run([inc_v_1, inc_v_2],
                 options=run_options, run_metadata=run_metadata)

        if i % 2 == 0:
          self.assertEqual(2, len(server.debug_tensor_values))
          self.assertAllClose(
              [5.0],
              server.debug_tensor_values["toggled_1:0:DebugIdentity"])
          self.assertAllClose(
              [-5.0],
              server.debug_tensor_values["toggled_2:0:DebugIdentity"])
        else:
          self.assertEqual(0, len(server.debug_tensor_values))

  def testToggleEnableTwoDebugWatchesNoCrosstalkBetweenServers(self):
    """Each server receives only the tensor it asked to watch."""
    with session.Session(
        config=session_debug_testlib.no_rewrite_session_config()) as sess:
      v = variables.Variable(50.0, name="v")
      delta = constant_op.constant(5.0, name="delta")
      inc_v = state_ops.assign_add(v, delta, name="inc_v")

      sess.run(v.initializer)

      run_metadata = config_pb2.RunMetadata()
      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.watch_graph(
          run_options,
          sess.graph,
          debug_ops=["DebugIdentity(gated_grpc=true)"],
          debug_urls=[self._debug_server_url_1, self._debug_server_url_2])

      for i in xrange(4):
        self._server_1.clear_data()
        self._server_2.clear_data()

        if i % 2 == 0:
          self._server_1.request_watch("delta", 0, "DebugIdentity")
          self._server_2.request_watch("v", 0, "DebugIdentity")
        else:
          self._server_1.request_unwatch("delta", 0, "DebugIdentity")
          self._server_2.request_unwatch("v", 0, "DebugIdentity")

        sess.run(inc_v, options=run_options, run_metadata=run_metadata)

        if i % 2 == 0:
          self.assertEqual(1, len(self._server_1.debug_tensor_values))
          self.assertEqual(1, len(self._server_2.debug_tensor_values))
          self.assertAllClose(
              [5.0],
              self._server_1.debug_tensor_values["delta:0:DebugIdentity"])
          self.assertAllClose(
              [50 + 5.0 * i],
              self._server_2.debug_tensor_values["v:0:DebugIdentity"])
        else:
          self.assertEqual(0, len(self._server_1.debug_tensor_values))
          self.assertEqual(0, len(self._server_2.debug_tensor_values))

  def testToggleBreakpointsWorks(self):
    """Breakpoints requested via request_watch are set and later cleared."""
    with session.Session(
        config=session_debug_testlib.no_rewrite_session_config()) as sess:
      v_1 = variables.Variable(50.0, name="v_1")
      v_2 = variables.Variable(-50.0, name="v_2")
      delta_1 = constant_op.constant(5.0, name="delta_1")
      delta_2 = constant_op.constant(-5.0, name="delta_2")
      inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
      inc_v_2 = state_ops.assign_add(v_2, delta_2, name="inc_v_2")

      sess.run([v_1.initializer, v_2.initializer])

      run_metadata = config_pb2.RunMetadata()
      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.watch_graph(
          run_options,
          sess.graph,
          debug_ops=["DebugIdentity(gated_grpc=true)"],
          debug_urls=[self._debug_server_url_1])

      for i in xrange(4):
        self._server_1.clear_data()

        if i in (0, 2):
          # Enable breakpoint at delta_[1,2]:0:DebugIdentity in runs 0 and 2.
          self._server_1.request_watch(
              "delta_1", 0, "DebugIdentity", breakpoint=True)
          self._server_1.request_watch(
              "delta_2", 0, "DebugIdentity", breakpoint=True)
        else:
          # Disable the breakpoint in runs 1 and 3.
          self._server_1.request_unwatch("delta_1", 0, "DebugIdentity")
          self._server_1.request_unwatch("delta_2", 0, "DebugIdentity")

        output = sess.run([inc_v_1, inc_v_2],
                          options=run_options, run_metadata=run_metadata)
        self.assertAllClose([50.0 + 5.0 * (i + 1), -50 - 5.0 * (i + 1)], output)

        if i in (0, 2):
          # During runs 0 and 2, the server should have received the published
          # debug tensor delta:0:DebugIdentity. The breakpoint should have been
          # unblocked by EventReply responses from the server.
          self.assertAllClose(
              [5.0],
              self._server_1.debug_tensor_values["delta_1:0:DebugIdentity"])
          self.assertAllClose(
              [-5.0],
              self._server_1.debug_tensor_values["delta_2:0:DebugIdentity"])
          # After the runs, the server should have properly registered the
          # breakpoints due to the request_watch calls.
          self.assertSetEqual({("delta_1", 0, "DebugIdentity"),
                               ("delta_2", 0, "DebugIdentity")},
                              self._server_1.breakpoints)
        else:
          # After the end of runs 1 and 3, the server has received the requests
          # to disable the breakpoint at delta:0:DebugIdentity.
          self.assertSetEqual(set(), self._server_1.breakpoints)

  def testTensorBoardDebuggerWrapperToggleBreakpointsWorks(self):
    """Breakpoint toggling and one-time traceback sending via the wrapper."""
    with session.Session(
        config=session_debug_testlib.no_rewrite_session_config()) as sess:
      v_1 = variables.Variable(50.0, name="v_1")
      v_2 = variables.Variable(-50.0, name="v_2")
      delta_1 = constant_op.constant(5.0, name="delta_1")
      delta_2 = constant_op.constant(-5.0, name="delta_2")
      inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
      inc_v_2 = state_ops.assign_add(v_2, delta_2, name="inc_v_2")

      sess.run([v_1.initializer, v_2.initializer])

      # The TensorBoardDebugWrapperSession should add a DebugIdentity debug op
      # with attribute gated_grpc=True for every tensor in the graph.
      sess = grpc_wrapper.TensorBoardDebugWrapperSession(
          sess, self._debug_server_url_1)

      for i in xrange(4):
        self._server_1.clear_data()

        if i in (0, 2):
          # Enable breakpoint at delta_[1,2]:0:DebugIdentity in runs 0 and 2.
          self._server_1.request_watch(
              "delta_1", 0, "DebugIdentity", breakpoint=True)
          self._server_1.request_watch(
              "delta_2", 0, "DebugIdentity", breakpoint=True)
        else:
          # Disable the breakpoint in runs 1 and 3.
          self._server_1.request_unwatch("delta_1", 0, "DebugIdentity")
          self._server_1.request_unwatch("delta_2", 0, "DebugIdentity")

        output = sess.run([inc_v_1, inc_v_2])
        self.assertAllClose([50.0 + 5.0 * (i + 1), -50 - 5.0 * (i + 1)], output)

        if i in (0, 2):
          # During runs 0 and 2, the server should have received the published
          # debug tensor delta:0:DebugIdentity. The breakpoint should have been
          # unblocked by EventReply responses from the server.
          self.assertAllClose(
              [5.0],
              self._server_1.debug_tensor_values["delta_1:0:DebugIdentity"])
          self.assertAllClose(
              [-5.0],
              self._server_1.debug_tensor_values["delta_2:0:DebugIdentity"])
          # After the runs, the server should have properly registered the
          # breakpoints.
          # NOTE(review): unlike testToggleBreakpointsWorks, no assertion on
          # self._server_1.breakpoints follows here - confirm whether that
          # omission is intentional.
        else:
          # After the end of runs 1 and 3, the server has received the requests
          # to disable the breakpoint at delta:0:DebugIdentity.
          self.assertSetEqual(set(), self._server_1.breakpoints)

        if i == 0:
          # Check that the server has received the stack trace.
          self.assertTrue(self._server_1.query_op_traceback("delta_1"))
          self.assertTrue(self._server_1.query_op_traceback("delta_2"))
          self.assertTrue(self._server_1.query_op_traceback("inc_v_1"))
          self.assertTrue(self._server_1.query_op_traceback("inc_v_2"))
          # Check that the server has received the python file content.
          # Query an arbitrary line to make sure that is the case.
          with open(__file__, "rt") as this_source_file:
            first_line = this_source_file.readline().strip()
          self.assertEqual(
              first_line, self._server_1.query_source_file_line(__file__, 1))
        else:
          # In later Session.run() calls, the traceback shouldn't have been sent
          # because it is already sent in the 1st call. So calling
          # query_op_traceback() should lead to an exception, because the test
          # debug server clears the data at the beginning of every iteration.
          with self.assertRaises(ValueError):
            self._server_1.query_op_traceback("delta_1")
          with self.assertRaises(ValueError):
            self._server_1.query_source_file_line(__file__, 1)

  def testTensorBoardDebuggerWrapperDisablingTracebackSourceSendingWorks(self):
    """send_traceback_and_source_code=False suppresses traceback/source."""
    with session.Session(
        config=session_debug_testlib.no_rewrite_session_config()) as sess:
      v_1 = variables.Variable(50.0, name="v_1")
      v_2 = variables.Variable(-50.0, name="v_2")
      delta_1 = constant_op.constant(5.0, name="delta_1")
      delta_2 = constant_op.constant(-5.0, name="delta_2")
      inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
      inc_v_2 = state_ops.assign_add(v_2, delta_2, name="inc_v_2")

      sess.run(variables.global_variables_initializer())

      # Disable the sending of traceback and source code.
      sess = grpc_wrapper.TensorBoardDebugWrapperSession(
          sess, self._debug_server_url_1, send_traceback_and_source_code=False)

      for i in xrange(4):
        self._server_1.clear_data()

        if i == 0:
          self._server_1.request_watch(
              "delta_1", 0, "DebugIdentity", breakpoint=True)

        output = sess.run([inc_v_1, inc_v_2])
        self.assertAllClose([50.0 + 5.0 * (i + 1), -50 - 5.0 * (i + 1)], output)

        # No op traceback or source code should have been received by the debug
        # server due to the disabling above.
        with self.assertRaisesRegexp(
            ValueError, r"Op .*delta_1.* does not exist"):
          self._server_1.query_op_traceback("delta_1")
        with self.assertRaisesRegexp(
            ValueError, r".* has not received any source file"):
          self._server_1.query_source_file_line(__file__, 1)

  def testGetGrpcDebugWatchesReturnsCorrectAnswer(self):
    """The server records only the watches created with gated_grpc=true."""
    with session.Session() as sess:
      v = variables.Variable(50.0, name="v")
      delta = constant_op.constant(5.0, name="delta")
      inc_v = state_ops.assign_add(v, delta, name="inc_v")

      sess.run(v.initializer)

      # Before any debugged runs, the server should be aware of no debug
      # watches.
      self.assertEqual([], self._server_1.gated_grpc_debug_watches())

      run_metadata = config_pb2.RunMetadata()
      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.add_debug_tensor_watch(
          run_options, "delta", output_slot=0,
          debug_ops=["DebugNumericSummary(gated_grpc=true)"],
          debug_urls=[self._debug_server_url_1])
      debug_utils.add_debug_tensor_watch(
          run_options, "v", output_slot=0,
          debug_ops=["DebugIdentity"],
          debug_urls=[self._debug_server_url_1])
      sess.run(inc_v, options=run_options, run_metadata=run_metadata)

      # After the first run, the server should have noted the debug watches
      # for which gated_grpc == True, but not the ones with gated_grpc == False.
      self.assertEqual(1, len(self._server_1.gated_grpc_debug_watches()))
      debug_watch = self._server_1.gated_grpc_debug_watches()[0]
      self.assertEqual("delta", debug_watch.node_name)
      self.assertEqual(0, debug_watch.output_slot)
      self.assertEqual("DebugNumericSummary", debug_watch.debug_op)
class DelayedDebugServerTest(test_util.TensorFlowTestCase):
  """Tests a debugged session whose debug server starts after a delay."""

  def testDebuggedSessionRunWorksWithDelayedDebugServerStartup(self):
    """Test debugged Session.run() tolerates delayed debug server startup."""
    ops.reset_default_graph()

    # Start a debug server asynchronously, with a certain amount of delay.
    (debug_server_port, _, _, server_thread,
     debug_server) = grpc_debug_test_server.start_server_on_separate_thread(
         server_start_delay_sec=2.0, dump_to_filesystem=False)

    try:
      with self.cached_session() as sess:
        a_init = constant_op.constant(42.0, name="a_init")
        a = variables.Variable(a_init, name="a")

        def watch_fn(fetches, feeds):
          del fetches, feeds
          return framework.WatchOptions(debug_ops=["DebugIdentity"])

        sess = grpc_wrapper.GrpcDebugWrapperSession(
            sess, "localhost:%d" % debug_server_port, watch_fn=watch_fn)
        sess.run(a.initializer)
        self.assertAllClose(
            [42.0], debug_server.debug_tensor_values["a_init:0:DebugIdentity"])
    finally:
      # Stop the server and join its thread even when the run or the assertion
      # above fails, so that a failing test does not leave them behind.
      debug_server.stop_server().wait()
      server_thread.join()
# Hand control to googletest's entry point when this file is run as a script.
if __name__ == "__main__":
  googletest.main()