Automated rollback of commit c1de94204eb0e49b98eb644daf0584856a8a9db3

PiperOrigin-RevId: 270922181
diff --git a/tensorflow/python/eager/benchmarks/resnet50/BUILD b/tensorflow/python/eager/benchmarks/resnet50/BUILD
new file mode 100644
index 0000000..469105e
--- /dev/null
+++ b/tensorflow/python/eager/benchmarks/resnet50/BUILD
@@ -0,0 +1,62 @@
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+
+package(
+    default_visibility = ["//tensorflow:internal"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+py_library(
+    name = "resnet50",
+    srcs = ["resnet50.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow:tensorflow_py_no_contrib",
+    ],
+)
+
+py_library(
+    name = "resnet50_test_lib",
+    srcs = ["resnet50_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":resnet50",
+        "//tensorflow:tensorflow_py_no_contrib",
+    ],
+)
+
+cuda_py_test(
+    name = "resnet50_test",
+    size = "medium",
+    srcs = ["resnet50_test.py"],
+    additional_deps = [
+        ":resnet50",
+        "//tensorflow:tensorflow_py_no_contrib",
+    ],
+    shard_count = 4,
+    tags = [
+        "no_windows",  # needs investigation
+        "optonly",
+        "oss_serial",
+    ],
+)
+
+cuda_py_test(
+    name = "resnet50_graph_test",
+    size = "medium",
+    srcs = ["resnet50_graph_test.py"],
+    additional_deps = [
+        ":resnet50",
+        ":resnet50_test_lib",
+        "//third_party/py/numpy",
+        "//tensorflow:tensorflow_py_no_contrib",
+    ],
+    shard_count = 4,
+    tags = [
+        "no_windows",  # needs investigation
+        "noasan",
+        "nomsan",
+        "notsan",
+        "optonly",
+        "oss_serial",
+    ],
+)
diff --git a/tensorflow/python/eager/benchmarks/resnet50/README.md b/tensorflow/python/eager/benchmarks/resnet50/README.md
new file mode 100644
index 0000000..79e4600
--- /dev/null
+++ b/tensorflow/python/eager/benchmarks/resnet50/README.md
@@ -0,0 +1,45 @@
+Image classification using the ResNet50 model described in
+[Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385).
+
+Contents:
+
+- `resnet50.py`: Model definition
+- `resnet50_test.py`: Sanity unittests and benchmarks for using the model with
+  eager execution enabled.
+- `resnet50_graph_test.py`: Sanity unittests and benchmarks when using the same
+  model code to construct a TensorFlow graph.
+
+# Benchmarks
+
+Using a synthetic data, run:
+
+```
+# Using eager execution
+python resnet50_test.py --benchmarks=.
+
+# Using graph execution
+python resnet50_graph_test.py --benchmarks=.
+```
+
+The above uses the model definition included with the TensorFlow pip
+package. To build (and run benchmarks) from source:
+
+```
+# Using eager execution
+bazel run -c opt --config=cuda :resnet50_test -- --benchmarks=.
+
+# Using graph execution
+bazel run -c opt --config=cuda :resnet50_graph_test -- --benchmarks=.
+```
+
+(Or remove the `--config=cuda` flag for running on CPU instead of GPU).
+
+On October 31, 2017, the benchmarks demonstrated comparable performance
+for eager and graph execution of this particular model when using
+a single NVIDIA Titan X (Pascal) GPU on a host with an
+Intel Xeon E5-1650 CPU @ 3.50GHz and a batch size of 32.
+
+| Benchmark name                           | batch size    | images/second |
+| ---------------------------------------  | ------------- | ------------- |
+| eager_train_gpu_batch_32_channels_first  |            32 |           171 |
+| graph_train_gpu_batch_32_channels_first  |            32 |           172 |
diff --git a/tensorflow/python/eager/benchmarks/resnet50/resnet50.py b/tensorflow/python/eager/benchmarks/resnet50/resnet50.py
new file mode 100644
index 0000000..9d090e8
--- /dev/null
+++ b/tensorflow/python/eager/benchmarks/resnet50/resnet50.py
@@ -0,0 +1,308 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ResNet50 model definition compatible with TensorFlow's eager execution.
+
+Reference [Deep Residual Learning for Image
+Recognition](https://arxiv.org/abs/1512.03385)
+
+Adapted from tf.keras.applications.ResNet50. A notable difference is that the
+model here outputs logits while the Keras model outputs probability.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+import tensorflow as tf
+
+layers = tf.keras.layers
+
+
+class _IdentityBlock(tf.keras.Model):
+  """_IdentityBlock is the block that has no conv layer at shortcut.
+
+  Args:
+    kernel_size: the kernel size of middle conv layer at main path
+    filters: list of integers, the filters of 3 conv layer at main path
+    stage: integer, current stage label, used for generating layer names
+    block: 'a','b'..., current block label, used for generating layer names
+    data_format: data_format for the input ('channels_first' or
+      'channels_last').
+  """
+
+  def __init__(self, kernel_size, filters, stage, block, data_format):
+    super(_IdentityBlock, self).__init__(name='')
+    filters1, filters2, filters3 = filters
+
+    conv_name_base = 'res' + str(stage) + block + '_branch'
+    bn_name_base = 'bn' + str(stage) + block + '_branch'
+    bn_axis = 1 if data_format == 'channels_first' else 3
+
+    self.conv2a = layers.Conv2D(
+        filters1, (1, 1), name=conv_name_base + '2a', data_format=data_format)
+    self.bn2a = layers.BatchNormalization(
+        axis=bn_axis, name=bn_name_base + '2a')
+
+    self.conv2b = layers.Conv2D(
+        filters2,
+        kernel_size,
+        padding='same',
+        data_format=data_format,
+        name=conv_name_base + '2b')
+    self.bn2b = layers.BatchNormalization(
+        axis=bn_axis, name=bn_name_base + '2b')
+
+    self.conv2c = layers.Conv2D(
+        filters3, (1, 1), name=conv_name_base + '2c', data_format=data_format)
+    self.bn2c = layers.BatchNormalization(
+        axis=bn_axis, name=bn_name_base + '2c')
+
+  def call(self, input_tensor, training=False):
+    x = self.conv2a(input_tensor)
+    x = self.bn2a(x, training=training)
+    x = tf.nn.relu(x)
+
+    x = self.conv2b(x)
+    x = self.bn2b(x, training=training)
+    x = tf.nn.relu(x)
+
+    x = self.conv2c(x)
+    x = self.bn2c(x, training=training)
+
+    x += input_tensor
+    return tf.nn.relu(x)
+
+
+class _ConvBlock(tf.keras.Model):
+  """_ConvBlock is the block that has a conv layer at shortcut.
+
+  Args:
+      kernel_size: the kernel size of middle conv layer at main path
+      filters: list of integers, the filters of 3 conv layer at main path
+      stage: integer, current stage label, used for generating layer names
+      block: 'a','b'..., current block label, used for generating layer names
+      data_format: data_format for the input ('channels_first' or
+        'channels_last').
+      strides: strides for the convolution. Note that from stage 3, the first
+       conv layer at main path is with strides=(2,2), and the shortcut should
+       have strides=(2,2) as well.
+  """
+
+  def __init__(self,
+               kernel_size,
+               filters,
+               stage,
+               block,
+               data_format,
+               strides=(2, 2)):
+    super(_ConvBlock, self).__init__(name='')
+    filters1, filters2, filters3 = filters
+
+    conv_name_base = 'res' + str(stage) + block + '_branch'
+    bn_name_base = 'bn' + str(stage) + block + '_branch'
+    bn_axis = 1 if data_format == 'channels_first' else 3
+
+    self.conv2a = layers.Conv2D(
+        filters1, (1, 1),
+        strides=strides,
+        name=conv_name_base + '2a',
+        data_format=data_format)
+    self.bn2a = layers.BatchNormalization(
+        axis=bn_axis, name=bn_name_base + '2a')
+
+    self.conv2b = layers.Conv2D(
+        filters2,
+        kernel_size,
+        padding='same',
+        name=conv_name_base + '2b',
+        data_format=data_format)
+    self.bn2b = layers.BatchNormalization(
+        axis=bn_axis, name=bn_name_base + '2b')
+
+    self.conv2c = layers.Conv2D(
+        filters3, (1, 1), name=conv_name_base + '2c', data_format=data_format)
+    self.bn2c = layers.BatchNormalization(
+        axis=bn_axis, name=bn_name_base + '2c')
+
+    self.conv_shortcut = layers.Conv2D(
+        filters3, (1, 1),
+        strides=strides,
+        name=conv_name_base + '1',
+        data_format=data_format)
+    self.bn_shortcut = layers.BatchNormalization(
+        axis=bn_axis, name=bn_name_base + '1')
+
+  def call(self, input_tensor, training=False):
+    x = self.conv2a(input_tensor)
+    x = self.bn2a(x, training=training)
+    x = tf.nn.relu(x)
+
+    x = self.conv2b(x)
+    x = self.bn2b(x, training=training)
+    x = tf.nn.relu(x)
+
+    x = self.conv2c(x)
+    x = self.bn2c(x, training=training)
+
+    shortcut = self.conv_shortcut(input_tensor)
+    shortcut = self.bn_shortcut(shortcut, training=training)
+
+    x += shortcut
+    return tf.nn.relu(x)
+
+
+# pylint: disable=not-callable
+class ResNet50(tf.keras.Model):
+  """Instantiates the ResNet50 architecture.
+
+  Args:
+    data_format: format for the image. Either 'channels_first' or
+      'channels_last'.  'channels_first' is typically faster on GPUs while
+      'channels_last' is typically faster on CPUs. See
+      https://www.tensorflow.org/performance/performance_guide#data_formats
+    name: Prefix applied to names of variables created in the model.
+    trainable: Is the model trainable? If true, performs backward
+        and optimization after call() method.
+    include_top: whether to include the fully-connected layer at the top of the
+      network.
+    pooling: Optional pooling mode for feature extraction when `include_top`
+      is `False`.
+      - `None` means that the output of the model will be the 4D tensor
+          output of the last convolutional layer.
+      - `avg` means that global average pooling will be applied to the output of
+          the last convolutional layer, and thus the output of the model will be
+          a 2D tensor.
+      - `max` means that global max pooling will be applied.
+    classes: optional number of classes to classify images into, only to be
+      specified if `include_top` is True.
+
+  Raises:
+      ValueError: in case of invalid argument for data_format.
+  """
+
+  def __init__(self,
+               data_format,
+               name='',
+               trainable=True,
+               include_top=True,
+               pooling=None,
+               classes=1000):
+    super(ResNet50, self).__init__(name=name)
+
+    valid_channel_values = ('channels_first', 'channels_last')
+    if data_format not in valid_channel_values:
+      raise ValueError('Unknown data_format: %s. Valid values: %s' %
+                       (data_format, valid_channel_values))
+    self.include_top = include_top
+
+    def conv_block(filters, stage, block, strides=(2, 2)):
+      return _ConvBlock(
+          3,
+          filters,
+          stage=stage,
+          block=block,
+          data_format=data_format,
+          strides=strides)
+
+    def id_block(filters, stage, block):
+      return _IdentityBlock(
+          3, filters, stage=stage, block=block, data_format=data_format)
+
+    self.conv1 = layers.Conv2D(
+        64, (7, 7),
+        strides=(2, 2),
+        data_format=data_format,
+        padding='same',
+        name='conv1')
+    bn_axis = 1 if data_format == 'channels_first' else 3
+    self.bn_conv1 = layers.BatchNormalization(axis=bn_axis, name='bn_conv1')
+    self.max_pool = layers.MaxPooling2D(
+        (3, 3), strides=(2, 2), data_format=data_format)
+
+    self.l2a = conv_block([64, 64, 256], stage=2, block='a', strides=(1, 1))
+    self.l2b = id_block([64, 64, 256], stage=2, block='b')
+    self.l2c = id_block([64, 64, 256], stage=2, block='c')
+
+    self.l3a = conv_block([128, 128, 512], stage=3, block='a')
+    self.l3b = id_block([128, 128, 512], stage=3, block='b')
+    self.l3c = id_block([128, 128, 512], stage=3, block='c')
+    self.l3d = id_block([128, 128, 512], stage=3, block='d')
+
+    self.l4a = conv_block([256, 256, 1024], stage=4, block='a')
+    self.l4b = id_block([256, 256, 1024], stage=4, block='b')
+    self.l4c = id_block([256, 256, 1024], stage=4, block='c')
+    self.l4d = id_block([256, 256, 1024], stage=4, block='d')
+    self.l4e = id_block([256, 256, 1024], stage=4, block='e')
+    self.l4f = id_block([256, 256, 1024], stage=4, block='f')
+
+    self.l5a = conv_block([512, 512, 2048], stage=5, block='a')
+    self.l5b = id_block([512, 512, 2048], stage=5, block='b')
+    self.l5c = id_block([512, 512, 2048], stage=5, block='c')
+
+    self.avg_pool = layers.AveragePooling2D(
+        (7, 7), strides=(7, 7), data_format=data_format)
+
+    if self.include_top:
+      self.flatten = layers.Flatten()
+      self.fc1000 = layers.Dense(classes, name='fc1000')
+    else:
+      reduction_indices = [1, 2] if data_format == 'channels_last' else [2, 3]
+      reduction_indices = tf.constant(reduction_indices)
+      if pooling == 'avg':
+        self.global_pooling = functools.partial(
+            tf.reduce_mean,
+            reduction_indices=reduction_indices,
+            keep_dims=False)
+      elif pooling == 'max':
+        self.global_pooling = functools.partial(
+            tf.reduce_max, reduction_indices=reduction_indices, keep_dims=False)
+      else:
+        self.global_pooling = None
+
+  def call(self, inputs, training=True):
+    x = self.conv1(inputs)
+    x = self.bn_conv1(x, training=training)
+    x = tf.nn.relu(x)
+    x = self.max_pool(x)
+
+    x = self.l2a(x, training=training)
+    x = self.l2b(x, training=training)
+    x = self.l2c(x, training=training)
+
+    x = self.l3a(x, training=training)
+    x = self.l3b(x, training=training)
+    x = self.l3c(x, training=training)
+    x = self.l3d(x, training=training)
+
+    x = self.l4a(x, training=training)
+    x = self.l4b(x, training=training)
+    x = self.l4c(x, training=training)
+    x = self.l4d(x, training=training)
+    x = self.l4e(x, training=training)
+    x = self.l4f(x, training=training)
+
+    x = self.l5a(x, training=training)
+    x = self.l5b(x, training=training)
+    x = self.l5c(x, training=training)
+
+    x = self.avg_pool(x)
+
+    if self.include_top:
+      return self.fc1000(self.flatten(x))
+    elif self.global_pooling:
+      return self.global_pooling(x)
+    else:
+      return x
diff --git a/tensorflow/python/eager/benchmarks/resnet50/resnet50_graph_test.py b/tensorflow/python/eager/benchmarks/resnet50/resnet50_graph_test.py
new file mode 100644
index 0000000..9be7997
--- /dev/null
+++ b/tensorflow/python/eager/benchmarks/resnet50/resnet50_graph_test.py
@@ -0,0 +1,130 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests and benchmarks for ResNet50 under graph execution."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+import tensorflow.compat.v1 as tf
+
+from tensorflow.python.eager.benchmarks.resnet50 import resnet50
+
+
+def data_format():
+  return 'channels_first' if tf.test.is_gpu_available() else 'channels_last'
+
+
+def image_shape(batch_size):
+  if data_format() == 'channels_first':
+    return [batch_size, 3, 224, 224]
+  return [batch_size, 224, 224, 3]
+
+
+def random_batch(batch_size):
+  images = np.random.rand(*image_shape(batch_size)).astype(np.float32)
+  num_classes = 1000
+  labels = np.random.randint(
+      low=0, high=num_classes, size=[batch_size]).astype(np.int32)
+  one_hot = np.zeros((batch_size, num_classes)).astype(np.float32)
+  one_hot[np.arange(batch_size), labels] = 1.
+  return images, one_hot
+
+
+class ResNet50GraphTest(tf.test.TestCase):
+
+  def testApply(self):
+    # Use small batches for tests because the OSS version runs
+    # in constrained GPU environment with 1-2GB of memory.
+    batch_size = 8
+    with tf.Graph().as_default():
+      images = tf.placeholder(tf.float32, image_shape(None))
+      model = resnet50.ResNet50(data_format())
+      predictions = model(images, training=False)
+
+      init = tf.global_variables_initializer()
+
+      with tf.Session() as sess:
+        sess.run(init)
+        np_images, _ = random_batch(batch_size)
+        out = sess.run(predictions, feed_dict={images: np_images})
+        self.assertAllEqual([batch_size, 1000], out.shape)
+
+
+class ResNet50Benchmarks(tf.test.Benchmark):
+
+  def _report(self, label, start, num_iters, batch_size):
+    avg_time = (time.time() - start) / num_iters
+    dev = 'gpu' if tf.test.is_gpu_available() else 'cpu'
+    name = 'graph_%s_%s_batch_%d_%s' % (label, dev, batch_size, data_format())
+    extras = {'examples_per_sec': batch_size / avg_time}
+    self.report_benchmark(
+        iters=num_iters, wall_time=avg_time, name=name, extras=extras)
+
+  def benchmark_graph_apply(self):
+    with tf.Graph().as_default():
+      images = tf.placeholder(tf.float32, image_shape(None))
+      model = resnet50.ResNet50(data_format())
+      predictions = model(images, training=False)
+
+      init = tf.global_variables_initializer()
+
+      batch_size = 64
+      with tf.Session() as sess:
+        sess.run(init)
+        np_images, _ = random_batch(batch_size)
+        num_burn, num_iters = (3, 30)
+        for _ in range(num_burn):
+          sess.run(predictions, feed_dict={images: np_images})
+        start = time.time()
+        for _ in range(num_iters):
+          # Comparison with the eager execution benchmark in resnet50_test.py
+          # isn't entirely fair as the time here includes the cost of copying
+          # the feeds from CPU memory to GPU.
+          sess.run(predictions, feed_dict={images: np_images})
+        self._report('apply', start, num_iters, batch_size)
+
+  def benchmark_graph_train(self):
+    for batch_size in [16, 32, 64]:
+      with tf.Graph().as_default():
+        np_images, np_labels = random_batch(batch_size)
+        dataset = tf.data.Dataset.from_tensors((np_images, np_labels)).repeat()
+        images, labels = tf.compat.v1.data.make_one_shot_iterator(
+            dataset).get_next()
+
+        model = resnet50.ResNet50(data_format())
+        logits = model(images, training=True)
+        loss = tf.losses.softmax_cross_entropy(
+            logits=logits, onehot_labels=labels)
+        optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
+        train_op = optimizer.minimize(loss)
+
+        init = tf.global_variables_initializer()
+        with tf.Session() as sess:
+          sess.run(init)
+          (num_burn, num_iters) = (5, 10)
+          for _ in range(num_burn):
+            sess.run(train_op)
+          start = time.time()
+          for _ in range(num_iters):
+            sess.run(train_op)
+          self._report('train', start, num_iters, batch_size)
+
+
+if __name__ == '__main__':
+  tf.test.main()
diff --git a/tensorflow/python/eager/benchmarks/resnet50/resnet50_test.py b/tensorflow/python/eager/benchmarks/resnet50/resnet50_test.py
new file mode 100644
index 0000000..931c115
--- /dev/null
+++ b/tensorflow/python/eager/benchmarks/resnet50/resnet50_test.py
@@ -0,0 +1,377 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests and benchmarks for the ResNet50 model, executed eagerly."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gc
+import os
+import tempfile
+import time
+
+from six.moves import xrange  # pylint: disable=redefined-builtin
+import tensorflow as tf
+
+from tensorflow.python.client import device_lib
+from tensorflow.python.eager import context
+from tensorflow.python.eager import tape
+from tensorflow.python.eager.benchmarks.resnet50 import resnet50
+
+
+def device_and_data_format():
+  if tf.config.experimental.list_physical_devices('GPU'):
+    return ('/gpu:0', 'channels_first')
+  return ('/cpu:0', 'channels_last')
+
+
+def random_batch(batch_size, data_format):
+  shape = (3, 224, 224) if data_format == 'channels_first' else (224, 224, 3)
+  shape = (batch_size,) + shape
+
+  num_classes = 1000
+  images = tf.random_uniform(shape)
+  labels = tf.random_uniform(
+      [batch_size], minval=0, maxval=num_classes, dtype=tf.int32)
+  one_hot = tf.one_hot(labels, num_classes)
+
+  return images, one_hot
+
+
+def compute_gradients(model, images, labels, num_replicas=1):
+  with tf.GradientTape() as grad_tape:
+    logits = model(images, training=True)
+    loss = tf.losses.softmax_cross_entropy(
+        logits=logits, onehot_labels=labels)
+    tf.compat.v2.summary.write('loss', loss)
+    if num_replicas != 1:
+      loss /= num_replicas
+
+  # TODO(b/110991947): We can mistakenly trace the gradient call in
+  # multi-threaded environment. Explicitly disable recording until
+  # this is fixed.
+  with tape.stop_recording():
+    grads = grad_tape.gradient(loss, model.variables)
+  return grads
+
+
+def apply_gradients(model, optimizer, gradients):
+  optimizer.apply_gradients(zip(gradients, model.variables))
+
+
+def _events_from_file(filepath):
+  """Returns all events in a single event file.
+
+  Args:
+    filepath: Path to the event file.
+
+  Returns:
+    A list of all tf.compat.v1.Event protos in the event file.
+  """
+  records = list(tf.python_io.tf_record_iterator(filepath))
+  result = []
+  for r in records:
+    event = tf.Event()
+    event.ParseFromString(r)
+    result.append(event)
+  return result
+
+
+def events_from_logdir(logdir):
+  """Returns all events in the single eventfile in logdir.
+
+  Args:
+    logdir: The directory in which the single event file is sought.
+
+  Returns:
+    A list of all tf.compat.v1.Event protos from the single event file.
+
+  Raises:
+    AssertionError: If logdir does not contain exactly one file.
+  """
+  assert tf.io.gfile.exists(logdir)
+  files = tf.io.gfile.listdir(logdir)
+  assert len(files) == 1, 'Found not exactly one file in logdir: %s' % files
+  return _events_from_file(os.path.join(logdir, files[0]))
+
+
+class ResNet50Test(tf.test.TestCase):
+
+  def _apply(self, defun=False, execution_mode=None):
+    device, data_format = device_and_data_format()
+    model = resnet50.ResNet50(data_format)
+    if defun:
+      model.call = tf.function(model.call)
+    with tf.device(device), context.execution_mode(execution_mode):
+      images, _ = random_batch(2, data_format)
+      output = model(images, training=False)
+      context.async_wait()
+    self.assertEqual((2, 1000), output.shape)
+
+  def test_apply(self):
+    self._apply(defun=False)
+
+  def test_apply_async(self):
+    self._apply(defun=False, execution_mode=context.ASYNC)
+
+  def test_apply_with_defun(self):
+    self._apply(defun=True)
+
+  def test_apply_with_defun_async(self):
+    self._apply(defun=True, execution_mode=context.ASYNC)
+
+  def test_apply_no_top(self):
+    device, data_format = device_and_data_format()
+    model = resnet50.ResNet50(data_format, include_top=False)
+    with tf.device(device):
+      images, _ = random_batch(2, data_format)
+      output = model(images, training=False)
+    output_shape = ((2, 2048, 1, 1)
+                    if data_format == 'channels_first' else (2, 1, 1, 2048))
+    self.assertEqual(output_shape, output.shape)
+
+  def test_apply_with_pooling(self):
+    device, data_format = device_and_data_format()
+    model = resnet50.ResNet50(data_format, include_top=False, pooling='avg')
+    with tf.device(device):
+      images, _ = random_batch(2, data_format)
+      output = model(images, training=False)
+    self.assertEqual((2, 2048), output.shape)
+
+  def _test_train(self, execution_mode=None):
+    device, data_format = device_and_data_format()
+    model = resnet50.ResNet50(data_format)
+    tf.compat.v2.summary.experimental.set_step(
+        tf.train.get_or_create_global_step())
+    logdir = tempfile.mkdtemp()
+    with tf.compat.v2.summary.create_file_writer(
+        logdir, max_queue=0,
+        name='t0').as_default(), tf.compat.v2.summary.record_if(True):
+      with tf.device(device), context.execution_mode(execution_mode):
+        optimizer = tf.train.GradientDescentOptimizer(0.1)
+        images, labels = random_batch(2, data_format)
+        apply_gradients(model, optimizer,
+                        compute_gradients(model, images, labels))
+        self.assertEqual(320, len(model.variables))
+        context.async_wait()
+    events = events_from_logdir(logdir)
+    self.assertEqual(len(events), 2)
+    self.assertEqual(events[1].summary.value[0].tag, 'loss')
+
+  def test_train(self):
+    self._test_train()
+
+  def test_train_async(self):
+    self._test_train(execution_mode=context.ASYNC)
+
+  def test_no_garbage(self):
+    device, data_format = device_and_data_format()
+    model = resnet50.ResNet50(data_format)
+    optimizer = tf.train.GradientDescentOptimizer(0.1)
+    with tf.device(device):
+      images, labels = random_batch(2, data_format)
+      gc.disable()
+      # Warm up. Note that this first run does create significant amounts of
+      # garbage to be collected. The hope is that this is a build-only effect,
+      # and a subsequent training loop will create nothing which needs to be
+      # collected.
+      apply_gradients(model, optimizer,
+                      compute_gradients(model, images, labels))
+      gc.collect()
+      previous_gc_debug_flags = gc.get_debug()
+      gc.set_debug(gc.DEBUG_SAVEALL)
+      for _ in range(2):
+        # Run twice to ensure that garbage that is created on the first
+        # iteration is no longer accessible.
+        apply_gradients(model, optimizer,
+                        compute_gradients(model, images, labels))
+      gc.collect()
+      # There should be no garbage requiring collection.
+      self.assertEqual(0, len(gc.garbage))
+      gc.set_debug(previous_gc_debug_flags)
+      gc.enable()
+
+
+class MockIterator(object):
+
+  def __init__(self, tensors):
+    self._tensors = [tf.identity(x) for x in tensors]
+
+  def next(self):
+    return self._tensors
+
+
+class ResNet50Benchmarks(tf.test.Benchmark):
+
+  def _train_batch_sizes(self):
+    """Choose batch sizes based on GPU capability."""
+    for device in device_lib.list_local_devices():
+      # TODO(b/141475121): We need some way to check which batch sizes would
+      # work using a public API.
+      if tf.DeviceSpec.from_string(device.name).device_type == 'GPU':
+        # Avoid OOM errors with larger batch sizes, which seem to cause errors
+        # later on even if caught.
+        #
+        # TODO(allenl): Base this on device memory; memory limit information
+        # during the test seems to exclude the amount TensorFlow has allocated,
+        # which isn't useful.
+        if 'K20' in device.physical_device_desc:
+          return (16,)
+        if 'P100' in device.physical_device_desc:
+          return (16, 32, 64)
+
+      if tf.DeviceSpec.from_string(device.name).device_type == 'TPU':
+        return (32,)
+    return (16, 32)
+
+  def _report(self, label, start, num_iters, device, batch_size, data_format,
+              num_replicas=1):
+    avg_time = (time.time() - start) / num_iters
+    dev = tf.DeviceSpec.from_string(device).device_type.lower()
+    replica_str = '' if num_replicas == 1 else 'replicas_%d_' % num_replicas
+    name = '%s_%s_batch_%d_%s%s' % (label, dev, batch_size,
+                                    replica_str, data_format)
+    extras = {'examples_per_sec': (num_replicas * batch_size) / avg_time}
+    self.report_benchmark(
+        iters=num_iters, wall_time=avg_time, name=name, extras=extras)
+
+  def _force_device_sync(self):
+    # If this function is called in the context of a non-CPU device
+    # (e.g., inside a 'with tf.device("/gpu:0")' block)
+    # then this will force a copy from CPU->NON_CPU_DEVICE->CPU,
+    # which forces a sync. This is a roundabout way, yes.
+    tf.constant(1.).cpu()
+
+  def _benchmark_eager_apply(self, label, device_and_format, defun=False,
+                             execution_mode=None):
+    with context.execution_mode(execution_mode):
+      device, data_format = device_and_format
+      model = resnet50.ResNet50(data_format)
+      if defun:
+        model.call = tf.function(model.call)
+      batch_size = 64
+      num_burn = 5
+      num_iters = 30
+      with tf.device(device):
+        images, _ = random_batch(batch_size, data_format)
+        for _ in xrange(num_burn):
+          model(images, training=False).cpu()
+        if execution_mode:
+          context.async_wait()
+        gc.collect()
+        start = time.time()
+        for _ in xrange(num_iters):
+          model(images, training=False).cpu()
+        if execution_mode:
+          context.async_wait()
+        self._report(label, start, num_iters, device, batch_size, data_format)
+
+  def benchmark_eager_apply_sync(self):
+    self._benchmark_eager_apply('eager_apply', device_and_data_format(),
+                                defun=False)
+
+  def benchmark_eager_apply_async(self):
+    self._benchmark_eager_apply(
+        'eager_apply_async', device_and_data_format(), defun=False,
+        execution_mode=context.ASYNC)
+
+  def benchmark_eager_apply_with_defun(self):
+    self._benchmark_eager_apply('eager_apply_with_defun',
+                                device_and_data_format(), defun=True)
+
+  def _benchmark_eager_train(self,
+                             label,
+                             make_iterator,
+                             device_and_format,
+                             defun=False,
+                             execution_mode=None):
+    with context.execution_mode(execution_mode):
+      device, data_format = device_and_format
+      for batch_size in self._train_batch_sizes():
+        (images, labels) = random_batch(batch_size, data_format)
+        model = resnet50.ResNet50(data_format)
+        optimizer = tf.train.GradientDescentOptimizer(0.1)
+        apply_grads = apply_gradients
+        if defun:
+          model.call = tf.function(model.call)
+          apply_grads = tf.function(apply_gradients)
+
+        num_burn = 3
+        num_iters = 10
+        with tf.device(device):
+          iterator = make_iterator((images, labels))
+          for _ in xrange(num_burn):
+            (images, labels) = iterator.next()
+            apply_grads(model, optimizer,
+                        compute_gradients(model, images, labels))
+          if execution_mode:
+            context.async_wait()
+          self._force_device_sync()
+          gc.collect()
+
+          start = time.time()
+          for _ in xrange(num_iters):
+            (images, labels) = iterator.next()
+            apply_grads(model, optimizer,
+                        compute_gradients(model, images, labels))
+          if execution_mode:
+            context.async_wait()
+          self._force_device_sync()
+          self._report(label, start, num_iters, device, batch_size, data_format)
+
+  def benchmark_eager_train_sync(self):
+    self._benchmark_eager_train('eager_train', MockIterator,
+                                device_and_data_format(), defun=False)
+
+  def benchmark_eager_train_async(self):
+    self._benchmark_eager_train(
+        'eager_train_async',
+        MockIterator,
+        device_and_data_format(),
+        defun=False,
+        execution_mode=context.ASYNC)
+
+  def benchmark_eager_train_with_defun(self):
+    self._benchmark_eager_train(
+        'eager_train_with_defun', MockIterator,
+        device_and_data_format(), defun=True)
+
+  def benchmark_eager_train_datasets(self):
+
+    def make_iterator(tensors):
+      with tf.device('/device:CPU:0'):
+        ds = tf.data.Dataset.from_tensors(tensors).repeat()
+      return iter(ds)
+
+    self._benchmark_eager_train(
+        'eager_train_dataset', make_iterator,
+        device_and_data_format(), defun=False)
+
+  def benchmark_eager_train_datasets_with_defun(self):
+
+    def make_iterator(tensors):
+      with tf.device('/device:CPU:0'):
+        ds = tf.data.Dataset.from_tensors(tensors).repeat()
+      return iter(ds)
+
+    self._benchmark_eager_train(
+        'eager_train_dataset_with_defun', make_iterator,
+        device_and_data_format(), defun=True)
+
+
+if __name__ == '__main__':
+  tf.enable_eager_execution()
+  tf.test.main()