[tfdbg] Add mnist and fib examples for tensorflow v2
Both examples replicate the same behavior of the tensorflow
v1 examples.
PiperOrigin-RevId: 269385509
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 4e4756b..988bfc7 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -51,6 +51,7 @@
":source_remote",
] + if_not_windows([
":debug_examples_v1",
+ ":debug_examples_v2",
]),
)
@@ -432,6 +433,14 @@
]),
)
+py_library(
+ name = "debug_examples_v2",
+ deps = [
+ ":debug_fibonacci_lib",
+ ":debug_mnist_lib",
+ ],
+)
+
py_binary(
name = "debug_fibonacci",
srcs = ["examples/v1/debug_fibonacci.py"],
@@ -440,9 +449,20 @@
deps = [":debug_fibonacci_lib"],
)
+py_binary(
+ name = "debug_fibonacci_v2",
+ srcs = ["examples/v2/debug_fibonacci_v2.py"],
+ python_version = "PY2",
+ srcs_version = "PY2AND3",
+ deps = [":debug_fibonacci_lib"],
+)
+
py_library(
name = "debug_fibonacci_lib",
- srcs = ["examples/v1/debug_fibonacci.py"],
+ srcs = [
+ "examples/v1/debug_fibonacci.py",
+ "examples/v2/debug_fibonacci_v2.py",
+ ],
srcs_version = "PY2AND3",
deps = [
":debug_py",
@@ -487,11 +507,20 @@
deps = [":debug_mnist_lib"],
)
+py_binary(
+ name = "debug_mnist_v2",
+ srcs = ["examples/v2/debug_mnist_v2.py"],
+ python_version = "PY2",
+ srcs_version = "PY2AND3",
+ deps = [":debug_mnist_lib"],
+)
+
py_library(
name = "debug_mnist_lib",
srcs = [
"examples/debug_mnist.py",
"examples/v1/debug_mnist_v1.py",
+ "examples/v2/debug_mnist_v2.py",
],
srcs_version = "PY2AND3",
deps = [
@@ -1231,3 +1260,16 @@
"no_windows",
],
)
+
+sh_test(
+ name = "examples_v2_test",
+ size = "medium",
+ srcs = ["examples/v2/examples_v2_test.sh"],
+ data = [
+ ":debug_fibonacci_v2",
+ ":debug_mnist_v2",
+ ],
+ tags = [
+ "no_windows",
+ ],
+)
diff --git a/tensorflow/python/debug/examples/debug_mnist.py b/tensorflow/python/debug/examples/debug_mnist.py
index 23e307e..c7e51d1e 100644
--- a/tensorflow/python/debug/examples/debug_mnist.py
+++ b/tensorflow/python/debug/examples/debug_mnist.py
@@ -19,9 +19,11 @@
import sys
+import absl
import tensorflow
import tensorflow.python.debug.examples.v1.debug_mnist_v1 as debug_mnist_v1
+import tensorflow.python.debug.examples.v2.debug_mnist_v2 as debug_mnist_v2
tf = tensorflow.compat.v1
@@ -34,7 +36,9 @@
with tf.Graph().as_default():
tf.app.run(main=debug_mnist_v1.main, argv=[sys.argv[0]] + unparsed)
else:
- tf.logging.info("tfdbg is not implemented in TensorFlow v2 yet")
+ flags, unparsed = debug_mnist_v2.parse_args()
+ debug_mnist_v2.FLAGS = flags
+ absl.app.run(main=debug_mnist_v2.main, argv=[sys.argv[0]] + unparsed)
if __name__ == "__main__":
diff --git a/tensorflow/python/debug/examples/v2/debug_fibonacci_v2.py b/tensorflow/python/debug/examples/v2/debug_fibonacci_v2.py
new file mode 100644
index 0000000..6e0eed1
--- /dev/null
+++ b/tensorflow/python/debug/examples/v2/debug_fibonacci_v2.py
@@ -0,0 +1,89 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Demo of the tfdbg curses UI: A TF v2 network computing Fibonacci sequence."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+
+import absl
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+import tensorflow.compat.v2 as tf
+
+FLAGS = None
+
+tf.compat.v1.enable_v2_behavior()
+
+
+def main(_):
+ # Wrap the TensorFlow Session object for debugging.
+ # TODO(anthonyjliu): Enable debugger from flags
+ if FLAGS.debug and FLAGS.tensorboard_debug_address:
+ raise ValueError(
+ "The --debug and --tensorboard_debug_address flags are mutually "
+ "exclusive.")
+ if FLAGS.debug:
+ raise NotImplementedError(
+ "tfdbg v2 support for debug_fibonacci is not implemented yet")
+ elif FLAGS.tensorboard_debug_address:
+ raise NotImplementedError(
+ "tfdbg v2 support for debug_fibonacci is not implemented yet")
+
+ # Construct the TensorFlow network.
+ n0 = tf.constant(np.ones([FLAGS.tensor_size] * 2), dtype=tf.int32)
+ n1 = tf.constant(np.ones([FLAGS.tensor_size] * 2), dtype=tf.int32)
+
+ for _ in xrange(2, FLAGS.length):
+ n0, n1 = n1, tf.add(n0, n1)
+
+ print("Fibonacci number at position %d:\n%s" % (FLAGS.length, n1.numpy()))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.register("type", "bool", lambda v: v.lower() == "true")
+ parser.add_argument(
+ "--tensor_size",
+ type=int,
+ default=1,
+ help="""\
+ Size of tensor. E.g., if the value is 30, the tensors will have shape
+ [30, 30].\
+ """)
+ parser.add_argument(
+ "--length",
+ type=int,
+ default=20,
+ help="Length of the fibonacci sequence to compute.")
+ parser.add_argument(
+ "--debug",
+ dest="debug",
+ action="store_true",
+ help="Use TensorFlow Debugger (tfdbg). Mutually exclusive with the "
+ "--tensorboard_debug_address flag.")
+ parser.add_argument(
+ "--tensorboard_debug_address",
+ type=str,
+ default=None,
+ help="Connect to the TensorBoard Debugger Plugin backend specified by "
+ "the gRPC address (e.g., localhost:1234). Mutually exclusive with the "
+ "--debug flag.")
+
+ FLAGS, unparsed = parser.parse_known_args()
+
+ absl.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/python/debug/examples/v2/debug_mnist_v2.py b/tensorflow/python/debug/examples/v2/debug_mnist_v2.py
new file mode 100644
index 0000000..b799f7f
--- /dev/null
+++ b/tensorflow/python/debug/examples/v2/debug_mnist_v2.py
@@ -0,0 +1,216 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Demo of the tfdbg curses CLI: Locating the source of bad numerical values with TF v2.
+
+This demo contains a classical example of a neural network for the mnist
+dataset, but modifications are made so that problematic numerical values (infs
+and nans) appear in nodes of the graph during training.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+
+import absl
+import tensorflow.compat.v2 as tf
+
+IMAGE_SIZE = 28
+HIDDEN_SIZE = 500
+NUM_LABELS = 10
+
+# If we set the weights randomly, the model will converge normally about half
+# the time. We need a seed to ensure that the bad numerical values issue
+# appears.
+RAND_SEED = 42
+
+tf.compat.v1.enable_v2_behavior()
+
+FLAGS = None
+
+
+def parse_args():
+ """Parses commandline arguments.
+
+ Returns:
+ A tuple (parsed, unparsed) of the parsed object and a group of unparsed
+ arguments that did not match the parser.
+ """
+ parser = argparse.ArgumentParser()
+ parser.register("type", "bool", lambda v: v.lower() == "true")
+ parser.add_argument(
+ "--max_steps",
+ type=int,
+ default=10,
+ help="Number of steps to run trainer.")
+ parser.add_argument(
+ "--train_batch_size",
+ type=int,
+ default=100,
+ help="Batch size used during training.")
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=0.025,
+ help="Initial learning rate.")
+ parser.add_argument(
+ "--data_dir",
+ type=str,
+ default="/tmp/mnist_data",
+ help="Directory for storing data")
+ parser.add_argument(
+ "--fake_data",
+ type="bool",
+ nargs="?",
+ const=True,
+ default=False,
+ help="Use fake MNIST data for unit testing")
+ parser.add_argument(
+ "--debug",
+ type="bool",
+ nargs="?",
+ const=True,
+ default=False,
+ help="Use debugger to track down bad values during training. "
+ "Mutually exclusive with the --tensorboard_debug_address flag.")
+ parser.add_argument(
+ "--tensorboard_debug_address",
+ type=str,
+ default=None,
+ help="Connect to the TensorBoard Debugger Plugin backend specified by "
+ "the gRPC address (e.g., localhost:1234). Mutually exclusive with the "
+ "--debug flag.")
+ parser.add_argument(
+ "--use_random_config_path",
+ type="bool",
+ nargs="?",
+ const=True,
+ default=False,
+ help="""If set, set config file path to a random file in the temporary
+ directory.""")
+ return parser.parse_known_args()
+
+
+def main(_):
+ # TODO(anthonyjliu): Enable debugger from flags
+ if FLAGS.debug and FLAGS.tensorboard_debug_address:
+ raise ValueError(
+ "The --debug and --tensorboard_debug_address flags are mutually "
+ "exclusive.")
+ if FLAGS.debug:
+ raise NotImplementedError(
+ "tfdbg v2 support for debug_mnist is not implemented yet")
+ elif FLAGS.tensorboard_debug_address:
+ raise NotImplementedError(
+ "tfdbg v2 support for debug_mnist is not implemented yet")
+
+ # Import data
+ if FLAGS.fake_data:
+ imgs = tf.random.uniform(maxval=256, shape=(10, 28, 28), dtype=tf.int32)
+ labels = tf.random.uniform(maxval=10, shape=(10,), dtype=tf.int32)
+ mnist_train = imgs, labels
+ mnist_test = imgs, labels
+ else:
+ mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()
+
+ @tf.function
+ def format_example(imgs, labels):
+ """Formats each training and test example to work with our model."""
+ imgs = tf.reshape(imgs, [-1, 28 * 28])
+ imgs = tf.cast(imgs, tf.float32) / 255.0
+ labels = tf.one_hot(labels, depth=10, dtype=tf.float32)
+ return imgs, labels
+
+ train_ds = tf.data.Dataset.from_tensor_slices(mnist_train).shuffle(
+ FLAGS.train_batch_size * FLAGS.max_steps,
+ seed=RAND_SEED).batch(FLAGS.train_batch_size)
+ train_ds = train_ds.map(format_example)
+
+ test_ds = tf.data.Dataset.from_tensor_slices(mnist_test).repeat().batch(
+ len(mnist_test[0]))
+ test_ds = test_ds.map(format_example)
+
+ def get_dense_weights(input_dim, output_dim):
+ """Initializes the parameters for a single dense layer."""
+ initial_kernel = tf.keras.initializers.TruncatedNormal(
+ mean=0.0, stddev=0.1, seed=RAND_SEED)
+ kernel = tf.Variable(initial_kernel([input_dim, output_dim]))
+ bias = tf.Variable(tf.constant(0.1, shape=[output_dim]))
+
+ return kernel, bias
+
+ @tf.function
+ def dense_layer(weights, input_tensor, act=tf.nn.relu):
+ """Runs the forward computation for a single dense layer."""
+ kernel, bias = weights
+ preactivate = tf.matmul(input_tensor, kernel) + bias
+
+ activations = act(preactivate)
+ return activations
+
+ # init model
+ hidden = get_dense_weights(IMAGE_SIZE**2, HIDDEN_SIZE)
+ logits = get_dense_weights(HIDDEN_SIZE, NUM_LABELS)
+ variables = hidden + logits
+
+ @tf.function
+ def model(x):
+ """Feed forward function of the model.
+
+ Args:
+ x: a (?, 28*28) tensor consisting of the feature inputs for a batch of
+ examples.
+
+ Returns:
+ A (?, 10) tensor containing the class scores for each example.
+ """
+ hidden_act = dense_layer(hidden, x)
+ logits_act = dense_layer(logits, hidden_act, tf.identity)
+ y = tf.nn.softmax(logits_act)
+ return y
+
+ @tf.function
+ def loss(logits, labels):
+ """Calculates cross entropy loss."""
+ diff = -(labels * tf.math.log(logits))
+ loss = tf.reduce_mean(diff)
+ return loss
+
+ train_batches = iter(train_ds)
+ test_batches = iter(test_ds)
+ optimizer = tf.optimizers.Adam(learning_rate=FLAGS.learning_rate)
+ for i in range(FLAGS.max_steps):
+ x_train, y_train = next(train_batches)
+ x_test, y_test = next(test_batches)
+
+ # Train Step
+ with tf.GradientTape() as tape:
+ y = model(x_train)
+ loss_val = loss(y, y_train)
+ grads = tape.gradient(loss_val, variables)
+
+ optimizer.apply_gradients(zip(grads, variables))
+
+ # Evaluation Step
+ y = model(x_test)
+ correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_test, 1))
+ accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
+ print("Accuracy at step %d: %s" % (i, accuracy.numpy()))
+
+
+if __name__ == "__main__":
+ FLAGS, unparsed = parse_args()
+ absl.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/python/debug/examples/v2/examples_v2_test.sh b/tensorflow/python/debug/examples/v2/examples_v2_test.sh
new file mode 100755
index 0000000..6ed84d2
--- /dev/null
+++ b/tensorflow/python/debug/examples/v2/examples_v2_test.sh
@@ -0,0 +1,68 @@
+#!/bin/bash
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Bash unit tests for TensorFlow Debugger (tfdbg) Python examples that do not
+# involve downloading data. Also tests the binary offline_analyzer.
+#
+# Command-line flags:
+# --virtualenv: (optional) If set, will test the examples and binaries
+# against pip install of TensorFlow in a virtualenv.
+
+set -e
+
+# Filter out LOG(INFO)
+export TF_CPP_MIN_LOG_LEVEL=1
+
+IS_VIRTUALENV=0
+PYTHON_BIN_PATH=""
+while true; do
+ if [[ -z "$1" ]]; then
+ break
+ elif [[ "$1" == "--virtualenv" ]]; then
+ IS_VIRTUALENV=1
+ PYTHON_BIN_PATH=$(which python)
+ echo
+ echo "IS_VIRTUALENV = ${IS_VIRTUALENV}"
+ echo "PYTHON_BIN_PATH = ${PYTHON_BIN_PATH}"
+ echo "Will test tfdbg examples and binaries against virtualenv pip install."
+ echo
+ fi
+ shift 1
+done
+
+if [[ -z "${PYTHON_BIN_PATH}" ]]; then
+ DEBUG_FIBONACCI_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_fibonacci_v2"
+ DEBUG_MNIST_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_mnist_v2"
+else
+ DEBUG_FIBONACCI_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.v2.debug_fibonacci"
+ DEBUG_MNIST_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.v2.debug_mnist"
+fi
+
+# Override the default ui_type=curses to allow the test to pass in a tty-less
+# test environment.
+cat << EOF | ${DEBUG_FIBONACCI_BIN} --tensor_size=2
+run
+exit
+EOF
+
+cat << EOF | ${DEBUG_MNIST_BIN} --max_steps=1 --fake_data
+run -t 1
+run --node_name_filter hidden --op_type_filter MatMul
+run -f has_inf_or_nan
+EOF
+
+echo
+echo "SUCCESS: tfdbg examples and binaries test PASSED"