[tfdbg] Add mnist and fib examples for tensorflow v2

Both examples replicate the same behavior of the tensorflow
v1 examples.

PiperOrigin-RevId: 269385509
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 4e4756b..988bfc7 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -51,6 +51,7 @@
         ":source_remote",
     ] + if_not_windows([
         ":debug_examples_v1",
+        ":debug_examples_v2",
     ]),
 )
 
@@ -432,6 +433,14 @@
     ]),
 )
 
+py_library(
+    name = "debug_examples_v2",
+    deps = [
+        ":debug_fibonacci_lib",
+        ":debug_mnist_lib",
+    ],
+)
+
 py_binary(
     name = "debug_fibonacci",
     srcs = ["examples/v1/debug_fibonacci.py"],
@@ -440,9 +449,20 @@
     deps = [":debug_fibonacci_lib"],
 )
 
+py_binary(
+    name = "debug_fibonacci_v2",
+    srcs = ["examples/v2/debug_fibonacci_v2.py"],
+    python_version = "PY2",
+    srcs_version = "PY2AND3",
+    deps = [":debug_fibonacci_lib"],
+)
+
 py_library(
     name = "debug_fibonacci_lib",
-    srcs = ["examples/v1/debug_fibonacci.py"],
+    srcs = [
+        "examples/v1/debug_fibonacci.py",
+        "examples/v2/debug_fibonacci_v2.py",
+    ],
     srcs_version = "PY2AND3",
     deps = [
         ":debug_py",
@@ -487,11 +507,20 @@
     deps = [":debug_mnist_lib"],
 )
 
+py_binary(
+    name = "debug_mnist_v2",
+    srcs = ["examples/v2/debug_mnist_v2.py"],
+    python_version = "PY2",
+    srcs_version = "PY2AND3",
+    deps = [":debug_mnist_lib"],
+)
+
 py_library(
     name = "debug_mnist_lib",
     srcs = [
         "examples/debug_mnist.py",
         "examples/v1/debug_mnist_v1.py",
+        "examples/v2/debug_mnist_v2.py",
     ],
     srcs_version = "PY2AND3",
     deps = [
@@ -1231,3 +1260,16 @@
         "no_windows",
     ],
 )
+
+sh_test(
+    name = "examples_v2_test",
+    size = "medium",
+    srcs = ["examples/v2/examples_v2_test.sh"],
+    data = [
+        ":debug_fibonacci_v2",
+        ":debug_mnist_v2",
+    ],
+    tags = [
+        "no_windows",
+    ],
+)
diff --git a/tensorflow/python/debug/examples/debug_mnist.py b/tensorflow/python/debug/examples/debug_mnist.py
index 23e307e..c7e51d1e 100644
--- a/tensorflow/python/debug/examples/debug_mnist.py
+++ b/tensorflow/python/debug/examples/debug_mnist.py
@@ -19,9 +19,11 @@
 
 import sys
 
+import absl
 import tensorflow
 
 import tensorflow.python.debug.examples.v1.debug_mnist_v1 as debug_mnist_v1
+import tensorflow.python.debug.examples.v2.debug_mnist_v2 as debug_mnist_v2
 
 tf = tensorflow.compat.v1
 
@@ -34,7 +36,9 @@
     with tf.Graph().as_default():
       tf.app.run(main=debug_mnist_v1.main, argv=[sys.argv[0]] + unparsed)
   else:
-    tf.logging.info("tfdbg is not implemented in TensorFlow v2 yet")
+    flags, unparsed = debug_mnist_v2.parse_args()
+    debug_mnist_v2.FLAGS = flags
+    absl.app.run(main=debug_mnist_v2.main, argv=[sys.argv[0]] + unparsed)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/debug/examples/v2/debug_fibonacci_v2.py b/tensorflow/python/debug/examples/v2/debug_fibonacci_v2.py
new file mode 100644
index 0000000..6e0eed1
--- /dev/null
+++ b/tensorflow/python/debug/examples/v2/debug_fibonacci_v2.py
@@ -0,0 +1,89 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Demo of the tfdbg curses UI: A TF v2 network computing Fibonacci sequence."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+
+import absl
+import numpy as np
+from six.moves import xrange  # pylint: disable=redefined-builtin
+import tensorflow.compat.v2 as tf
+
+FLAGS = None
+
+tf.compat.v1.enable_v2_behavior()
+
+
+def main(_):
+  # Wrap the TensorFlow Session object for debugging.
+  # TODO(anthonyjliu): Enable debugger from flags
+  if FLAGS.debug and FLAGS.tensorboard_debug_address:
+    raise ValueError(
+        "The --debug and --tensorboard_debug_address flags are mutually "
+        "exclusive.")
+  if FLAGS.debug:
+    raise NotImplementedError(
+        "tfdbg v2 support for debug_fibonacci is not implemented yet")
+  elif FLAGS.tensorboard_debug_address:
+    raise NotImplementedError(
+        "tfdbg v2 support for debug_fibonacci is not implemented yet")
+
+  # Construct the TensorFlow network.
+  n0 = tf.constant(np.ones([FLAGS.tensor_size] * 2), dtype=tf.int32)
+  n1 = tf.constant(np.ones([FLAGS.tensor_size] * 2), dtype=tf.int32)
+
+  for _ in xrange(2, FLAGS.length):
+    n0, n1 = n1, tf.add(n0, n1)
+
+  print("Fibonacci number at position %d:\n%s" % (FLAGS.length, n1.numpy()))
+
+
+if __name__ == "__main__":
+  parser = argparse.ArgumentParser()
+  parser.register("type", "bool", lambda v: v.lower() == "true")
+  parser.add_argument(
+      "--tensor_size",
+      type=int,
+      default=1,
+      help="""\
+      Size of tensor. E.g., if the value is 30, the tensors will have shape
+      [30, 30].\
+      """)
+  parser.add_argument(
+      "--length",
+      type=int,
+      default=20,
+      help="Length of the fibonacci sequence to compute.")
+  parser.add_argument(
+      "--debug",
+      dest="debug",
+      action="store_true",
+      help="Use TensorFlow Debugger (tfdbg). Mutually exclusive with the "
+      "--tensorboard_debug_address flag.")
+  parser.add_argument(
+      "--tensorboard_debug_address",
+      type=str,
+      default=None,
+      help="Connect to the TensorBoard Debugger Plugin backend specified by "
+      "the gRPC address (e.g., localhost:1234). Mutually exclusive with the "
+      "--debug flag.")
+
+  FLAGS, unparsed = parser.parse_known_args()
+
+  absl.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/python/debug/examples/v2/debug_mnist_v2.py b/tensorflow/python/debug/examples/v2/debug_mnist_v2.py
new file mode 100644
index 0000000..b799f7f
--- /dev/null
+++ b/tensorflow/python/debug/examples/v2/debug_mnist_v2.py
@@ -0,0 +1,216 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Demo of the tfdbg curses CLI: Locating the source of bad numerical values with TF v2.
+
+This demo contains a classical example of a neural network for the mnist
+dataset, but modifications are made so that problematic numerical values (infs
+and nans) appear in nodes of the graph during training.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+
+import absl
+import tensorflow.compat.v2 as tf
+
+IMAGE_SIZE = 28
+HIDDEN_SIZE = 500
+NUM_LABELS = 10
+
+# If we set the weights randomly, the model will converge normally about half
+# the time. We need a seed to ensure that the bad numerical values issue
+# appears.
+RAND_SEED = 42
+
+tf.compat.v1.enable_v2_behavior()
+
+FLAGS = None
+
+
+def parse_args():
+  """Parses commandline arguments.
+
+  Returns:
+    A tuple (parsed, unparsed) of the parsed object and a group of unparsed
+      arguments that did not match the parser.
+  """
+  parser = argparse.ArgumentParser()
+  parser.register("type", "bool", lambda v: v.lower() == "true")
+  parser.add_argument(
+      "--max_steps",
+      type=int,
+      default=10,
+      help="Number of steps to run trainer.")
+  parser.add_argument(
+      "--train_batch_size",
+      type=int,
+      default=100,
+      help="Batch size used during training.")
+  parser.add_argument(
+      "--learning_rate",
+      type=float,
+      default=0.025,
+      help="Initial learning rate.")
+  parser.add_argument(
+      "--data_dir",
+      type=str,
+      default="/tmp/mnist_data",
+      help="Directory for storing data")
+  parser.add_argument(
+      "--fake_data",
+      type="bool",
+      nargs="?",
+      const=True,
+      default=False,
+      help="Use fake MNIST data for unit testing")
+  parser.add_argument(
+      "--debug",
+      type="bool",
+      nargs="?",
+      const=True,
+      default=False,
+      help="Use debugger to track down bad values during training. "
+      "Mutually exclusive with the --tensorboard_debug_address flag.")
+  parser.add_argument(
+      "--tensorboard_debug_address",
+      type=str,
+      default=None,
+      help="Connect to the TensorBoard Debugger Plugin backend specified by "
+      "the gRPC address (e.g., localhost:1234). Mutually exclusive with the "
+      "--debug flag.")
+  parser.add_argument(
+      "--use_random_config_path",
+      type="bool",
+      nargs="?",
+      const=True,
+      default=False,
+      help="""If set, set config file path to a random file in the temporary
+      directory.""")
+  return parser.parse_known_args()
+
+
+def main(_):
+  # TODO(anthonyjliu): Enable debugger from flags
+  if FLAGS.debug and FLAGS.tensorboard_debug_address:
+    raise ValueError(
+        "The --debug and --tensorboard_debug_address flags are mutually "
+        "exclusive.")
+  if FLAGS.debug:
+    raise NotImplementedError(
+        "tfdbg v2 support for debug_mnist is not implemented yet")
+  elif FLAGS.tensorboard_debug_address:
+    raise NotImplementedError(
+        "tfdbg v2 support for debug_mnist is not implemented yet")
+
+  # Import data
+  if FLAGS.fake_data:
+    imgs = tf.random.uniform(maxval=256, shape=(10, 28, 28), dtype=tf.int32)
+    labels = tf.random.uniform(maxval=10, shape=(10,), dtype=tf.int32)
+    mnist_train = imgs, labels
+    mnist_test = imgs, labels
+  else:
+    mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()
+
+  @tf.function
+  def format_example(imgs, labels):
+    """Formats each training and test example to work with our model."""
+    imgs = tf.reshape(imgs, [-1, 28 * 28])
+    imgs = tf.cast(imgs, tf.float32) / 255.0
+    labels = tf.one_hot(labels, depth=10, dtype=tf.float32)
+    return imgs, labels
+
+  train_ds = tf.data.Dataset.from_tensor_slices(mnist_train).shuffle(
+      FLAGS.train_batch_size * FLAGS.max_steps,
+      seed=RAND_SEED).batch(FLAGS.train_batch_size)
+  train_ds = train_ds.map(format_example)
+
+  test_ds = tf.data.Dataset.from_tensor_slices(mnist_test).repeat().batch(
+      len(mnist_test[0]))
+  test_ds = test_ds.map(format_example)
+
+  def get_dense_weights(input_dim, output_dim):
+    """Initializes the parameters for a single dense layer."""
+    initial_kernel = tf.keras.initializers.TruncatedNormal(
+        mean=0.0, stddev=0.1, seed=RAND_SEED)
+    kernel = tf.Variable(initial_kernel([input_dim, output_dim]))
+    bias = tf.Variable(tf.constant(0.1, shape=[output_dim]))
+
+    return kernel, bias
+
+  @tf.function
+  def dense_layer(weights, input_tensor, act=tf.nn.relu):
+    """Runs the forward computation for a single dense layer."""
+    kernel, bias = weights
+    preactivate = tf.matmul(input_tensor, kernel) + bias
+
+    activations = act(preactivate)
+    return activations
+
+  # init model
+  hidden = get_dense_weights(IMAGE_SIZE**2, HIDDEN_SIZE)
+  logits = get_dense_weights(HIDDEN_SIZE, NUM_LABELS)
+  variables = hidden + logits
+
+  @tf.function
+  def model(x):
+    """Feed forward function of the model.
+
+    Args:
+      x: a (?, 28*28) tensor consisting of the feature inputs for a batch of
+        examples.
+
+    Returns:
+      A (?, 10) tensor containing the class scores for each example.
+    """
+    hidden_act = dense_layer(hidden, x)
+    logits_act = dense_layer(logits, hidden_act, tf.identity)
+    y = tf.nn.softmax(logits_act)
+    return y
+
+  @tf.function
+  def loss(logits, labels):
+    """Calculates cross entropy loss."""
+    diff = -(labels * tf.math.log(logits))
+    loss = tf.reduce_mean(diff)
+    return loss
+
+  train_batches = iter(train_ds)
+  test_batches = iter(test_ds)
+  optimizer = tf.optimizers.Adam(learning_rate=FLAGS.learning_rate)
+  for i in range(FLAGS.max_steps):
+    x_train, y_train = next(train_batches)
+    x_test, y_test = next(test_batches)
+
+    # Train Step
+    with tf.GradientTape() as tape:
+      y = model(x_train)
+      loss_val = loss(y, y_train)
+    grads = tape.gradient(loss_val, variables)
+
+    optimizer.apply_gradients(zip(grads, variables))
+
+    # Evaluation Step
+    y = model(x_test)
+    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_test, 1))
+    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
+    print("Accuracy at step %d: %s" % (i, accuracy.numpy()))
+
+
+if __name__ == "__main__":
+  FLAGS, unparsed = parse_args()
+  absl.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/python/debug/examples/v2/examples_v2_test.sh b/tensorflow/python/debug/examples/v2/examples_v2_test.sh
new file mode 100755
index 0000000..6ed84d2
--- /dev/null
+++ b/tensorflow/python/debug/examples/v2/examples_v2_test.sh
@@ -0,0 +1,68 @@
+#!/bin/bash
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Bash unit tests for TensorFlow Debugger (tfdbg) Python examples that do not
+# involve downloading data. Also tests the binary offline_analyzer.
+#
+# Command-line flags:
+#   --virtualenv: (optional) If set, will test the examples and binaries
+#     against pip install of TensorFlow in a virtualenv.
+
+set -e
+
+# Filter out LOG(INFO)
+export TF_CPP_MIN_LOG_LEVEL=1
+
+IS_VIRTUALENV=0
+PYTHON_BIN_PATH=""
+while true; do
+  if [[ -z "$1" ]]; then
+    break
+  elif [[ "$1" == "--virtualenv" ]]; then
+    IS_VIRTUALENV=1
+    PYTHON_BIN_PATH=$(which python)
+    echo
+    echo "IS_VIRTUALENV = ${IS_VIRTUALENV}"
+    echo "PYTHON_BIN_PATH = ${PYTHON_BIN_PATH}"
+    echo "Will test tfdbg examples and binaries against virtualenv pip install."
+    echo
+  fi
+  shift 1
+done
+
+if [[ -z "${PYTHON_BIN_PATH}" ]]; then
+  DEBUG_FIBONACCI_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_fibonacci_v2"
+  DEBUG_MNIST_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_mnist_v2"
+else
+  DEBUG_FIBONACCI_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.v2.debug_fibonacci"
+  DEBUG_MNIST_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.v2.debug_mnist"
+fi
+
+# Override the default ui_type=curses to allow the test to pass in a tty-less
+# test environment.
+cat << EOF | ${DEBUG_FIBONACCI_BIN} --tensor_size=2
+run
+exit
+EOF
+
+cat << EOF | ${DEBUG_MNIST_BIN} --max_steps=1 --fake_data
+run -t 1
+run --node_name_filter hidden --op_type_filter MatMul
+run -f has_inf_or_nan
+EOF
+
+echo
+echo "SUCCESS: tfdbg examples and binaries test PASSED"